add device generalisation support for distributed tests (#152471)
### MOTIVATION

To generalize distributed test cases for non-CUDA devices.

### CHANGES

- test/distributed/optim/test_zero_redundancy_optimizer.py
- test/distributed/test_c10d_logger.py
- test/distributed/test_compute_comm_reordering.py

  Replaced hard-coded device names with `get_devtype` from `torch.testing._internal.common_fsdp`. `DistributedTestBase` is used instead of `MultiProcessTestCase` to make use of its helper functions.

- torch/testing/_internal/common_distributed.py

  Extended the common utility functions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152471
Approved by: https://github.com/d4l3k
Commit e1f28fe17b (parent 0aed855b2b), committed by PyTorch MergeBot.
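For context, a minimal sketch of the test pattern this PR moves the files toward: the device and backend come from `get_devtype` and `DistributedTestBase.create_pg` instead of hard-coded `"cuda"`/NCCL. The class and test names below are invented for illustration; only `DistributedTestBase`, `create_pg`, and `get_devtype` come from the code touched by this change.

```python
# Illustrative sketch only: ExampleDeviceGenericTest and test_step are made-up names;
# DistributedTestBase, create_pg, and get_devtype are the helpers this PR uses/extends.
import torch
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.optim import SGD
from torch.testing._internal.common_distributed import DistributedTestBase
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import run_tests

device_type = str(get_devtype())  # e.g. "cuda", "xpu", or "hpu" depending on the build


class ExampleDeviceGenericTest(DistributedTestBase):
    @property
    def device(self):
        return device_type

    @property
    def world_size(self):
        return 1

    def test_step(self):
        # Old pattern: self.dist_init(self.rank) with a module-level BACKEND constant.
        # New pattern: create_pg() derives the backend from the device type.
        self.create_pg(self.device)
        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        opt = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=0.01)
        x.sum().backward()
        opt.step()


if __name__ == "__main__":
    run_tests()
```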
test/distributed/optim/test_zero_redundancy_optimizer.py:

@@ -6,7 +6,6 @@
 # LICENSE file in the root directory of this source tree.
 import copy
-import os
 import sys
 from contextlib import nullcontext
 from typing import Any, cast

@@ -30,12 +29,23 @@ from torch.distributed.optim import ZeroRedundancyOptimizer
 from torch.distributed.optim.zero_redundancy_optimizer import _broadcast_object
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import AdamW, SGD
-from torch.testing._internal import common_distributed
+from torch.testing._internal.common_distributed import (
+    DistributedTestBase,
+    logger,
+    requires_accelerator_dist_backend,
+    requires_ddp_rank,
+    requires_gloo,
+    skip_if_lt_x_gpu,
+    skip_if_no_gpu,
+    skip_if_rocm_multiprocess,
+    skip_if_win32,
+)
+from torch.testing._internal.common_fsdp import get_devtype
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
-    IS_WINDOWS,
     parametrize,
     run_tests,
+    skipIfHpu,
 )

@@ -47,63 +57,24 @@ except ImportError:
 HAS_TORCHVISION = False

-# Use GLOO on GPU when running CUDA + Windows
-def _get_backend_for_tests():
-    return (
-        dist.Backend.NCCL
-        if not IS_WINDOWS and torch.cuda.is_available()
-        # Windows only has GLOO, but GLOO GPU works. And use GLOO CPU when
-        # no GPUs are available.
-        else dist.Backend.GLOO
-    )
-
-
-BACKEND = _get_backend_for_tests()
+device_type = str(get_devtype())

-class TestZeroRedundancyOptimizer(common_distributed.MultiProcessTestCase):
-    def setUp(self):
-        super().setUp()
-        os.environ["WORLD_SIZE"] = str(self.world_size)
-        self._spawn_processes()
-
+class TestZeroRedundancyOptimizer(DistributedTestBase):
     @property
     def device(self):
-        return (
-            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-        )
+        return device_type

     @property
     def world_size(self):
         return 1

-    def tearDown(self):
-        try:
-            torch.distributed.destroy_process_group()
-        except AssertionError:
-            pass
-        try:
-            os.remove(self.file_name)
-        except OSError:
-            pass
-
-    def dist_init(self, rank, world_size=-1, backend=BACKEND):
-        if world_size < 1:
-            world_size = self.world_size
-        store = dist.FileStore(self.file_name, world_size)
-        return dist.init_process_group(
-            backend=backend,
-            store=store,
-            rank=rank,
-            world_size=world_size,
-        )
-

 class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
     def test_state_dict(self):
         """Check that ZeroRedundancyOptimizer exposes the expected state dict
         interface, irrespective of the sharding."""
-        self.dist_init(self.rank)
+        self.create_pg(self.device)

@@ -171,7 +142,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
     def test_lr_scheduler(self):
-        self.dist_init(self.rank)
+        self.create_pg(self.device)

@@ -193,7 +164,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
     def test_step_with_kwargs(self):
-        self.dist_init(self.rank)
+        self.create_pg(self.device)

@@ -217,7 +188,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
         """Check that ZeroRedundancyOptimizer wrapping an optimizer that adds
         extra keys to ``param_groups`` exposes those keys through ZeRO's own
         ``param_groups``."""
-        self.dist_init(self.rank)
+        self.create_pg(self.device)

@@ -236,7 +207,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
     def test_step_without_closure(self):
-        self.dist_init(self.rank)
+        self.create_pg(self.device)

@@ -255,7 +226,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
     def test_zero_grad(self):
-        self.dist_init(self.rank)
+        self.create_pg(self.device)

@@ -271,7 +242,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
     def test_constructor(self):
-        self.dist_init(self.rank)
+        self.create_pg(self.device)

@@ -336,7 +307,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
         NOTE: This test should be removed once support for sparse parameters
         and varying parameter types is added.
         """
-        self.dist_init(self.rank)
+        self.create_pg(self.device)

@@ -353,25 +324,16 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
 class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
-    @property
-    def device(self):
-        return (
-            torch.device(self.rank)
-            if torch.cuda.is_available()
-            else torch.device("cpu")
-        )
-
     @property
     def world_size(self):
-        return min(4, max(2, torch.cuda.device_count()))
+        return min(4, max(2, torch.get_device_module(self.device).device_count()))

     @property
     def context(self):
-        return (
-            nullcontext()
-            if not torch.cuda.is_available()
-            else torch.cuda.device(self.rank)
-        )
+        if requires_ddp_rank(self.device):
+            return torch.get_device_module(self.device).device(self.rank)
+        else:
+            return nullcontext()

@@ -396,12 +358,12 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
-    @common_distributed.skip_if_no_gpu
-    @common_distributed.skip_if_rocm_multiprocess
+    @skip_if_no_gpu
+    @skip_if_rocm_multiprocess
     def test_step(self):
-        self.dist_init(self.rank, world_size=self.world_size)
+        self.create_pg(self.device)

@@ -436,13 +398,12 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
-    @common_distributed.skip_if_no_gpu
-    @common_distributed.skip_if_rocm_multiprocess
+    @skip_if_no_gpu
+    @skip_if_rocm_multiprocess
     def test_step_with_closure(self):
-        self.dist_init(self.rank, world_size=self.world_size)
+        self.create_pg(self.device)

@@ -487,11 +448,11 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
-    @common_distributed.skip_if_no_gpu
+    @skip_if_no_gpu
     def test_lr_scheduler(self):
-        self.dist_init(self.rank)
+        self.create_pg(self.device)

@@ -519,7 +480,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
-        self.dist_init(self.rank)
+        self.create_pg(self.device)

@@ -541,7 +502,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
-        self.dist_init(self.rank)
+        self.create_pg(self.device)

@@ -589,14 +550,14 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
-    @common_distributed.skip_if_no_gpu
+    @skip_if_no_gpu
     def test_multiple_param_groups(self):
-        self.dist_init(self.rank)
+        self.create_pg(self.device)

@@ -656,12 +617,12 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
-    @common_distributed.skip_if_no_gpu
-    @common_distributed.skip_if_rocm_multiprocess
+    @skip_if_no_gpu
+    @skip_if_rocm_multiprocess
     def test_collect_shards(self):
-        self.dist_init(self.rank)
+        self.create_pg(self.device)

@@ -719,27 +680,25 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
         if self.world_size < MIN_WORLD_SIZE:
-            common_distributed.logger.info(
+            logger.info(
                 "Skipping `test_nondefault_process_group()` since world size "
                 "of %s is less than %s",
             return
-        BACKEND = dist.Backend.GLOO
-        self.dist_init(self.rank, self.world_size, BACKEND)
-        # Use GPU if enough are available, or fall back to CPU otherwise, which
-        # is fine since Gloo backend supports both
-        if torch.cuda.is_available() and torch.cuda.device_count() >= self.world_size:
-            device = torch.device(self.rank)
-        else:
+        # Use GPU if enough are available, or fall back to CPU otherwise
+        if torch.get_device_module(self.device).device_count() < self.world_size:
             device = torch.device("cpu")
+        else:
+            device = torch.device(self.device)
+        self.create_pg(device.type)
         process_group = dist.new_group(
             ranks=subgroup_ranks,
-            backend=BACKEND,
+            backend=self.backend(device.type),
         )

@@ -811,7 +770,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
-    @common_distributed.skip_if_no_gpu
+    @skip_if_no_gpu
     @parametrize(
         "optimizer_class_str",
         ["Adam", "AdamW", "SGD"],

@@ -828,7 +787,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
         """When combined with DDP, check that a local optimizer gives the same
         results as wrapping that optimizer with ZeroRedundancyOptimizer."""
-        self.dist_init(self.rank)
+        self.create_pg(self.device)

@@ -867,7 +826,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
         sharded_ddp_model = DDP(
             module=model,
-            device_ids=[self.rank],
+            device_ids=[self.rank] if requires_ddp_rank(self.device) else None,
             broadcast_buffers=True,
             find_unused_parameters=True,
         )

@@ -879,7 +838,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
         ddp_model = DDP(
             local_model,
-            device_ids=[self.rank],
+            device_ids=[self.rank] if requires_ddp_rank(self.device) else None,
             broadcast_buffers=True,
             find_unused_parameters=True,
         )

@@ -892,7 +851,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
         def check_step():
-            input_tensor = torch.rand((BATCH_SIZE, INPUT_DIM))
+            input_tensor = torch.rand((BATCH_SIZE, INPUT_DIM)).to(self.device)

@@ -970,13 +929,12 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
         torch.manual_seed(0)
-        torch.cuda.manual_seed(0)
+        if "cpu" not in device:
+            torch.get_device_module(device).manual_seed(0)
         rank = self.rank
         world_size = self.world_size
-        is_gpu = device.type == "cuda"
-        backend = _get_backend_for_tests() if is_gpu else dist.Backend.GLOO
-        self.dist_init(rank, world_size, backend)
+        self.create_pg(device)

@@ -988,7 +946,9 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
         # DDP ensures correct gradients in data parallel training, so DDP with
         # local optimizers on uneven inputs should be equivalent to ZeRO on
         # uneven inputs with gradients being manually set
-        ddp_model = DDP(model, device_ids=[rank]) if is_gpu else DDP(model)
+        ddp_model = (
+            DDP(model, device_ids=[rank]) if requires_ddp_rank(device) else DDP(model)
+        )
         local_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR)

@@ -1111,27 +1071,28 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
-    @common_distributed.requires_nccl()
-    @common_distributed.skip_if_no_gpu
+    @requires_accelerator_dist_backend()
+    @skip_if_no_gpu
     def test_zero_join_gpu(self):

-    @common_distributed.requires_gloo()
+    @requires_gloo()
     def test_zero_join_cpu(self):
-        self._test_zero_join(torch.device("cpu"))
+        self._test_zero_join("cpu")

-    def _test_zero_model_parallel(self, parameters_as_bucket_view: bool):
+    def _test_zero_model_parallel(self, parameters_as_bucket_view: bool, device: str):
         torch.manual_seed(0)
-        torch.cuda.manual_seed(0)
+        if "cpu" not in device:
+            torch.get_device_module(device).manual_seed(0)

@@ -1221,7 +1182,8 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
-    @common_distributed.skip_if_lt_x_gpu(4)
+    @skipIfHpu
+    @skip_if_lt_x_gpu(4)
     @parametrize(
         "parameters_as_bucket_view",
         [False, True],

@@ -1234,8 +1196,8 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
-        self.dist_init(self.rank, world_size=2)
-        self._test_zero_model_parallel(parameters_as_bucket_view)
+        self.create_pg(self.device, world_size=2)
+        self._test_zero_model_parallel(parameters_as_bucket_view, self.device)

@@ -1250,12 +1212,10 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
         torch.manual_seed(0)
-        torch.cuda.manual_seed(0)
+        if "cpu" not in device:
+            torch.get_device_module(device).manual_seed(0)
         rank = self.rank
-        is_gpu = device.type == "cuda"
-        if is_gpu:
-            torch.cuda.set_device(device)

@@ -1273,11 +1233,16 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
         for model, inputs in models_to_test:
-            # Enable determinism in cudnn operators
-            with torch.backends.cudnn.flags(
-                enabled=True, deterministic=True, benchmark=False
-            ):
-                device_ids = [rank] if is_gpu else None
+            # Select deterministic context based on device
+            det_ctx = (
+                torch.backends.cudnn.flags(
+                    enabled=True, deterministic=True, benchmark=False
+                )
+                if "cuda" in device
+                else torch.use_deterministic_algorithms(True)
+            )
+            with det_ctx:
+                device_ids = [rank] if requires_ddp_rank(device) else None
                 # Set up the DDP model overlapping with ZeRO
                 ddp_model_overlap = DDP(
                     copy.deepcopy(model).to(device),

@@ -1374,10 +1339,10 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
     # NOTE: The test is skipped if using Windows since functional optimizers
     # are not currently supported.
-    @common_distributed.skip_if_win32()
-    @common_distributed.requires_nccl()
-    @common_distributed.skip_if_no_gpu
-    @common_distributed.skip_if_rocm_multiprocess
+    @skip_if_win32()
+    @requires_accelerator_dist_backend()
+    @skip_if_no_gpu
+    @skip_if_rocm_multiprocess
     @parametrize(
         "use_gpu",
         [True],

@@ -1413,9 +1378,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
-        device = torch.device(self.rank) if use_gpu else torch.device("cpu")
-        backend = _get_backend_for_tests()
-        self.dist_init(self.rank, self.world_size, backend)
+        self.create_pg(self.device)

@@ -1423,7 +1386,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
         self._test_ddp_zero_overlap(
-            device,
+            self.device if use_gpu else "cpu",
             hook_constructor,
             gradient_as_bucket_view,
             static_graph,
test/distributed/test_c10d_logger.py:

@@ -2,7 +2,6 @@
 import json
 import logging
-import os
 import re
 import sys
 from functools import partial, wraps

@@ -16,10 +15,13 @@ if not dist.is_available():
     print("Distributed not available, skipping tests", file=sys.stderr)
     sys.exit(0)

-from torch.testing._internal.common_distributed import MultiProcessTestCase, TEST_SKIPS
+from torch.testing._internal.common_distributed import DistributedTestBase, TEST_SKIPS
+from torch.testing._internal.common_fsdp import get_devtype
 from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
+
+device_type = str(get_devtype())

 if TEST_WITH_DEV_DBG_ASAN:

@@ -27,8 +29,7 @@ if TEST_WITH_DEV_DBG_ASAN:
     sys.exit(0)

-BACKEND = dist.Backend.NCCL
-WORLD_SIZE = min(4, max(2, torch.cuda.device_count()))
+WORLD_SIZE = min(4, max(2, torch.get_device_module(device_type).device_count()))

@@ -39,30 +40,16 @@ def with_comms(func=None):
     @wraps(func)
     def wrapper(self, *args, **kwargs):
-        if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
+        if torch.get_device_module(device_type).device_count() < self.world_size:
             sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
-        self.dist_init()
+        self.create_pg(device_type)
         func(self)
         self.destroy_comms()

     return wrapper


-class C10dErrorLoggerTest(MultiProcessTestCase):
+class C10dErrorLoggerTest(DistributedTestBase):
-    def setUp(self):
-        super().setUp()
-        os.environ["WORLD_SIZE"] = str(self.world_size)
-        os.environ["BACKEND"] = BACKEND
-        self._spawn_processes()
-
-    @property
-    def device(self):
-        return (
-            torch.device(self.rank)
-            if BACKEND == dist.Backend.NCCL
-            else torch.device("cpu")
-        )
-
     @property
     def world_size(self):
         return WORLD_SIZE

@@ -76,18 +63,6 @@ class C10dErrorLoggerTest(MultiProcessTestCase):
         dist.barrier()
         dist.destroy_process_group()

-    def dist_init(self):
-        dist.init_process_group(
-            backend=BACKEND,
-            world_size=self.world_size,
-            rank=self.rank,
-            init_method=f"file://{self.file_name}",
-        )
-
-        # set device for nccl pg for collectives
-        if BACKEND == "nccl":
-            torch.cuda.set_device(self.rank)
-
     def test_get_or_create_logger(self):
         self.assertIsNotNone(_c10d_logger)
         self.assertEqual(1, len(_c10d_logger.handlers))

@@ -117,7 +92,11 @@ class C10dErrorLoggerTest(MultiProcessTestCase):
             re.search("({.+})", captured.output[0]).group(0).replace("'", '"')
         )
-        self.assertEqual(len(error_msg_dict), 9)
+        # NCCL adds additional nccl_version data to the error_msg_dict
+        if self.backend(device_type) == dist.Backend.NCCL:
+            self.assertEqual(len(error_msg_dict), 9)
+        else:
+            self.assertEqual(len(error_msg_dict), 8)

         self.assertIn("pg_name", error_msg_dict.keys())
         self.assertEqual("None", error_msg_dict["pg_name"])

@@ -126,13 +105,14 @@ class C10dErrorLoggerTest(MultiProcessTestCase):
         self.assertEqual("broadcast", error_msg_dict["func_name"])

         self.assertIn("backend", error_msg_dict.keys())
-        self.assertEqual("nccl", error_msg_dict["backend"])
+        self.assertEqual(self.backend(device_type), error_msg_dict["backend"])

-        self.assertIn("nccl_version", error_msg_dict.keys())
-        nccl_ver = torch.cuda.nccl.version()
-        self.assertEqual(
-            ".".join(str(v) for v in nccl_ver), error_msg_dict["nccl_version"]
-        )
+        if self.backend(device_type) == dist.Backend.NCCL:
+            self.assertIn("nccl_version", error_msg_dict.keys())
+            nccl_ver = torch.cuda.nccl.version()
+            self.assertEqual(
+                ".".join(str(v) for v in nccl_ver), error_msg_dict["nccl_version"]
+            )

         # In this test case, group_size = world_size, since we don't have multiple processes on one node.
         self.assertIn("group_size", error_msg_dict.keys())
test/distributed/test_compute_comm_reordering.py:

@@ -26,12 +26,16 @@ from torch.testing._internal.common_distributed import (
     _dynamo_dist_per_rank_init,
     at_least_x_gpu,
     DynamoDistributedMultiProcTestCase,
-    requires_nccl,
+    requires_accelerator_dist_backend,
 )
+from torch.testing._internal.common_fsdp import get_devtype
 from torch.testing._internal.common_utils import skipIfRocm
 from torch.testing._internal.inductor_utils import HAS_GPU
+
+device_type = str(get_devtype())

 def get_snode_runtime_for_reorder_compute_test(snode):
     # NOTE: custom cost model to show that the compute reordering algorithm is working

@@ -74,7 +78,7 @@ def create_grouped_node_for_allreduce_and_its_deps(snodes):
     return new_snode_order

-@requires_nccl()
+@requires_accelerator_dist_backend()
 class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):

@@ -113,9 +117,12 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
         with _dynamo_dist_per_rank_init(
-            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
+            self.rank,
+            self.world_size,
+            self.backend(device_type),
+            fake_pg=not at_least_x_gpu(2),
         ):
-            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
+            inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
             compiled = torch.compile(func)

@@ -154,9 +161,12 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
         with _dynamo_dist_per_rank_init(
-            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
+            self.rank,
+            self.world_size,
+            self.backend(device_type),
+            fake_pg=not at_least_x_gpu(2),
         ):
-            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
+            inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank

@@ -202,9 +212,12 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
         with _dynamo_dist_per_rank_init(
-            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
+            self.rank,
+            self.world_size,
+            self.backend(device_type),
+            fake_pg=not at_least_x_gpu(2),
         ):
-            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
+            inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank

@@ -255,9 +268,12 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
         with _dynamo_dist_per_rank_init(
-            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
+            self.rank,
+            self.world_size,
+            self.backend(device_type),
+            fake_pg=not at_least_x_gpu(2),
         ):
-            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
+            inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank

@@ -312,9 +328,12 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
         with _dynamo_dist_per_rank_init(
-            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
+            self.rank,
+            self.world_size,
+            self.backend(device_type),
+            fake_pg=not at_least_x_gpu(2),
         ):
-            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
+            inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank

@@ -362,9 +381,12 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
         with _dynamo_dist_per_rank_init(
-            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
+            self.rank,
+            self.world_size,
+            self.backend(device_type),
+            fake_pg=not at_least_x_gpu(2),
        ):
-            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
+            inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank

@@ -387,9 +409,9 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
         ranks = pg_info["ranks"]
         group_size = pg_info["group_size"]

-        g1 = torch.ones(10, 10, device="cuda")
-        g2 = torch.ones(11, 11, device="cuda")
-        g3 = torch.ones(12, 12, device="cuda")
+        g1 = torch.ones(10, 10, device=device_type)
+        g2 = torch.ones(11, 11, device=device_type)
+        g3 = torch.ones(12, 12, device=device_type)

@@ -429,7 +451,9 @@ graph():
             grad1 = torch.ops._c10d_functional.wait_tensor.default(handle1)
             return grad3, grad2, grad1

-        with _dynamo_dist_per_rank_init(self.rank, self.world_size, fake_pg=True):
+        with _dynamo_dist_per_rank_init(
+            self.rank, self.world_size, self.backend(device_type), fake_pg=True
+        ):
             fn(g1, g2, g3)

     def test_nccl_heuristics(self):
@@ -17,7 +17,9 @@ from torch.distributed.tensor.parallel import (
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing import FileCheck
-from torch.testing._internal.common_utils import run_tests, TestCase
+from torch.testing._internal.common_distributed import HAS_ACCELERATOR
+from torch.testing._internal.common_fsdp import get_devtype
+from torch.testing._internal.common_utils import run_tests, skipIfHpu, TestCase
 from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
 from torch.testing._internal.distributed.fake_pg import FakeStore

@@ -26,13 +28,16 @@ if not dist.is_available():
     print("Distributed not available, skipping tests", file=sys.stderr)
     sys.exit(0)

-HAS_CUDA = torch.cuda.is_available()
+device_type = get_devtype().type


 class TestFakePG(TestCase):
     def tearDown(self):
         super().tearDown()
-        dist.destroy_process_group()
+        try:
+            dist.destroy_process_group()
+        except AssertionError:
+            pass

     def test_all_reduce(self):
         store = FakeStore()

@@ -62,20 +67,21 @@ class TestFakePG(TestCase):
         dist.reduce_scatter(output_tensor, to_reduce_scatter)
         self.assertEqual(tuple(output_tensor.shape), (3, 3))

-    @unittest.skipIf(not HAS_CUDA, "No CUDA")
+    @unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
     def test_construct_fsdp(self):
         store = FakeStore()
         dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
-        FSDP(nn.Linear(2, 3, device="cuda"))
+        FSDP(nn.Linear(2, 3, device=device_type))

-    @unittest.skipIf(not HAS_CUDA, "No CUDA")
+    @skipIfHpu
+    @unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
     def test_fsdp_fake_e2e(self):
         store = dist.HashStore()
         dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
         my_module = nn.Sequential(
-            nn.Linear(2, 3, device="cuda"),
+            nn.Linear(2, 3, device=device_type),
             nn.ReLU(),
-            nn.Linear(3, 2, device="cuda"),
+            nn.Linear(3, 2, device=device_type),
         )
         sharded_module = FSDP(my_module, use_orig_params=True)
         optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)

@@ -85,7 +91,8 @@ class TestFakePG(TestCase):
         loss.backward()
         optim.step()

-    @unittest.skipIf(not HAS_CUDA, "No CUDA")
+    @skipIfHpu
+    @unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
     def test_fake_pg_tracing(self):
         store = dist.HashStore()
         dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

@@ -95,7 +102,7 @@ class TestFakePG(TestCase):
         def allgather_fn(tensor):
             return funcol.all_gather_tensor(tensor, 0, default_pg)

-        gm = make_fx(allgather_fn)(torch.randn(2, 2, device="cuda"))
+        gm = make_fx(allgather_fn)(torch.randn(2, 2, device=device_type))
         FileCheck().check("all_gather").check("wait_tensor").run(str(gm.graph))

     def test_broadcast(self):

@@ -165,7 +172,8 @@ class TestFakePG(TestCase):
         dist.recv(output, 1)
         self.assertEqual(tuple(output.shape), (3, 3))

-    @unittest.skipIf(not HAS_CUDA, "No CUDA or TP+FSDP")
+    @skipIfHpu
+    @unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
     def test_fsdp_tp_fake_e2e(self):
         world_size = 4
         tp_size = 2

@@ -175,9 +183,11 @@ class TestFakePG(TestCase):
             backend="fake", rank=0, world_size=world_size, store=store
         )

-        device_mesh = DeviceMesh("cuda", torch.arange(0, world_size).view(-1, tp_size))
+        device_mesh = DeviceMesh(
+            device_type, torch.arange(0, world_size).view(-1, tp_size)
+        )
         device_mesh = init_device_mesh(
-            "cuda", (world_size // tp_size, tp_size), mesh_dim_names=["dp", "tp"]
+            device_type, (world_size // tp_size, tp_size), mesh_dim_names=["dp", "tp"]
         )

         sequence_parallelize_plan = {

@@ -190,7 +200,7 @@ class TestFakePG(TestCase):
         }
         for parallel_plan in [sequence_parallelize_plan, pairwise_parallelize_plan]:
             my_module = parallelize_module(
-                MLPModule(device="cuda"),
+                MLPModule(device=device_type),
                 device_mesh["tp"],
                 parallel_plan,
             )

@@ -203,7 +213,7 @@ class TestFakePG(TestCase):
         for i in range(10):
             dp_rank = dist.get_rank()
             torch.manual_seed(i + dp_rank)
-            input = torch.randn(20, 10).cuda(dist.get_rank())
+            input = torch.randn(20, 10, device=f"{device_type}:{dp_rank}")
             x = sharded_module(input)
             loss = x.sum()
             loss.backward()
torch/testing/_internal/common_distributed.py:

@@ -39,6 +39,7 @@ from torch.testing._internal.common_utils import (
     retry_on_connect_failures,
     skip_but_pass_in_sandcastle,
     skip_but_pass_in_sandcastle_if,
+    TEST_CUDA,
     TEST_HPU,
     TEST_WITH_ROCM,
     TEST_WITH_TSAN,

@@ -55,6 +56,10 @@ from torch.testing._internal.distributed.multi_threaded_pg import (
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)

+ACCELERATOR_DIST_BACKENDS = ["nccl", "xccl", "hccl"]
+DDP_RANK_DEVICES = ["cuda", "xpu"]
+HAS_ACCELERATOR = TEST_CUDA or TEST_HPU or TEST_XPU
+

 class TestSkip(NamedTuple):
     exit_code: int

@@ -109,21 +114,25 @@ class DistTestCases:
     backend_feature["xpu"] = {"xccl"}


+def requires_ddp_rank(device):
+    return device in DDP_RANK_DEVICES
+
+
 def skip_if_no_gpu(func):
     """Skips if the world size exceeds the number of GPUs, ensuring that if the
     test is run, each rank has its own GPU via ``torch.cuda.device(rank)``."""

     @wraps(func)
     def wrapper(*args, **kwargs):
-        if not torch.cuda.is_available():
+        if not (TEST_CUDA or TEST_HPU or TEST_XPU):
             sys.exit(TEST_SKIPS["no_cuda"].exit_code)
         world_size = int(os.environ["WORLD_SIZE"])
-        if torch.cuda.device_count() < world_size:
+        if TEST_CUDA and torch.cuda.device_count() < world_size:
             sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
-        if TEST_HPU and torch.hpu.device_count < world_size:
+        if TEST_HPU and torch.hpu.device_count() < world_size:
+            sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
+        if TEST_XPU and torch.xpu.device_count() < world_size:
             sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
-        if TEST_XPU and torch.xpu.device_count < world_size:
-            sys.exit(TEST_SKIPS[f"multi-xpu-{world_size}"].exit_code)

         return func(*args, **kwargs)

@@ -189,7 +198,13 @@ def import_transformers_or_skip():
 def at_least_x_gpu(x):
-    return torch.cuda.is_available() and torch.cuda.device_count() >= x
+    if TEST_CUDA and torch.cuda.device_count() >= x:
+        return True
+    if TEST_HPU and torch.hpu.device_count() >= x:
+        return True
+    if TEST_XPU and torch.xpu.device_count() >= x:
+        return True
+    return False


 def skip_if_lt_x_gpu(x):

@@ -355,6 +370,35 @@ def requires_mpi():
     )


+def requires_accelerator_dist_backend(backends=None):
+    """
+    Decorator to skip tests if no accelerator communication backend (NCCL, XCCL, HCCL) is available.
+
+    Args:
+        backends (Optional[List[str]]): Specific accelerator backends to check (e.g., ["nccl", "xccl", "hccl"]).
+            If None, checks all supported accelerator backends (NCCL, XCCL, HCCL).
+
+    Returns:
+        callable: A decorator that skips the test if no specified accelerator backend is available.
+    """
+    if backends is None:
+        backends = ACCELERATOR_DIST_BACKENDS
+
+    backend_available = any(
+        {
+            "nccl": c10d.is_nccl_available,
+            "xccl": c10d.is_xccl_available,
+            "hccl": lambda: TEST_HPU,
+        }.get(backend, lambda: False)()
+        for backend in backends
+    )
+
+    return skip_but_pass_in_sandcastle_if(
+        not backend_available,
+        f"No accelerator communication backend available among {backends}",
+    )
+
+
 def requires_multicast_support():
     has_multicast_support = (
         torch.cuda.is_available()

@@ -968,9 +1012,14 @@ class MultiProcessTestCase(TestCase):
 class DistributedTestBase(MultiProcessTestCase):
     def setUp(self):
         super().setUp()
+        os.environ["WORLD_SIZE"] = str(self.world_size)
         self._spawn_processes()

     def tearDown(self):
+        try:
+            torch.distributed.destroy_process_group()
+        except AssertionError:
+            pass
         try:
             os.remove(self.file_name)
         except OSError:

@@ -986,12 +1035,14 @@ class DistributedTestBase(MultiProcessTestCase):
         else:
             return "gloo"

-    def create_pg(self, device):
+    def create_pg(self, device, world_size=None):
+        if world_size is None:
+            world_size = self.world_size
         num_visible_devices = torch.get_device_module(device).device_count()
         store = torch.distributed.FileStore(self.file_name, num_visible_devices)
         torch.distributed.init_process_group(
             backend=self.backend(device),
-            world_size=self.world_size,
+            world_size=world_size,
             rank=self.rank,
             store=store,
         )

@@ -1404,7 +1455,9 @@ class SaveForwardInputsModel(nn.Module):
 @contextmanager
-def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True, fake_pg=False):
+def _dynamo_dist_per_rank_init(
+    rank, world_size, backend="nccl", init_pg=True, fake_pg=False
+):
     # To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase,
     # Just manually implement the most important part of the dynamo behavior to reset/clear.
     if not fake_pg:

@@ -1421,7 +1474,7 @@ def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True, fake_pg=False):
             store=store,
         )
     else:
-        c10d.init_process_group("nccl", rank=rank, world_size=world_size)
+        c10d.init_process_group(backend=backend, rank=rank, world_size=world_size)
     torch._dynamo.reset()
     torch._dynamo.utils.counters.clear()
     try:

@@ -1465,7 +1518,7 @@ class DynamoDistributedSingleProcTestCase(torch._dynamo.test_case.TestCase):
         super().tearDownClass()


-class DynamoDistributedMultiProcTestCase(MultiProcessTestCase):
+class DynamoDistributedMultiProcTestCase(DistributedTestBase):
     """
     Use this for tests that actually run on multiple GPUs.

@@ -1476,20 +1529,9 @@ class DynamoDistributedMultiProcTestCase(MultiProcessTestCase):
     sparingly for integration tests.
     """

-    def setUp(self):
-        super().setUp()
-        self._spawn_processes()
-
-    def tearDown(self):
-        super().tearDown()
-        try:
-            os.remove(self.file_name)
-        except OSError:
-            pass
-
     @property
     def world_size(self) -> int:
-        return torch.cuda.device_count()
+        return torch.accelerator.device_count()

     @classmethod
     def _run(