Anthony Barbier
2025-10-02 15:48:47 +00:00
committed by PyTorch MergeBot
parent c6329524d8
commit ac7b4e7fe4
13 changed files with 226 additions and 104 deletions

View File

@ -21,6 +21,16 @@ from _pytest.terminal import _get_raw_skip_reason
from pytest_shard_custom import pytest_addoptions as shard_addoptions, PytestShardPlugin
try:
from torch.testing._internal.common_utils import parse_cmd_line_args
except ImportError:
# Temporary workaround needed until parse_cmd_line_args makes it into a nightly, because
# main / PR tests are sometimes run against the previous day's nightly, which won't
# have this function.
def parse_cmd_line_args():
pass
if TYPE_CHECKING:
from _pytest._code.code import ReprFileLocation
@ -83,6 +93,7 @@ def pytest_addoption(parser: Parser) -> None:
def pytest_configure(config: Config) -> None:
parse_cmd_line_args()
xmlpath = config.option.xmlpath_reruns
# Prevent opening xmllog on worker nodes (xdist).
if xmlpath and not hasattr(config, "workerinput"):
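
As a hedged aside on why hooking the parse into pytest_configure is sufficient: pytest runs this hook after parsing its own options and before collecting or importing any test module, so module-level code in those tests already sees the values parse_cmd_line_args() sets. A minimal sketch of the hook only (not the real conftest, which also handles XML reruns and sharding):

# conftest.py -- minimal sketch of the hook, not the real file
def pytest_configure(config):
    # pytest invokes this before it collects/imports test modules, so
    # globals such as GRAPH_EXECUTOR are populated in time for
    # module-level skip decorators in those tests.
    from torch.testing._internal.common_utils import parse_cmd_line_args

    parse_cmd_line_args()
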

View File

@ -27,6 +27,9 @@ from torch.testing._internal.jit_utils import (
)
assert GRAPH_EXECUTOR is not None
@unittest.skipIf(
GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients"
)

View File

@ -3,6 +3,13 @@
import torch
if __name__ == '__main__':
from torch.testing._internal.common_utils import parse_cmd_line_args
# The value of GRAPH_EXECUTOR depends on command-line arguments, so make sure they're parsed
# before instantiating tests.
parse_cmd_line_args()
# This is how we include tests located in test/jit/...
# They are included here so that they are invoked when you call `test_jit.py`,
# do not run these test files directly.
@ -97,7 +104,7 @@ import torch.nn.functional as F
from torch.testing._internal import jit_utils
from torch.testing._internal.common_jit import check_against_reference
from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, \
suppress_warnings, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, \
GRAPH_EXECUTOR, suppress_warnings, IS_SANDCASTLE, ProfilingMode, \
TestCase, freeze_rng_state, slowTest, TemporaryFileName, \
enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \
skipIfCrossRef, skipIfTorchDynamo
@ -158,6 +165,7 @@ def doAutodiffCheck(testname):
if "test_t_" in testname or testname == "test_t":
return False
assert GRAPH_EXECUTOR
if GRAPH_EXECUTOR == ProfilingMode.SIMPLE:
return False
@ -201,6 +209,7 @@ def doAutodiffCheck(testname):
return testname not in test_exceptions
assert GRAPH_EXECUTOR
# TODO: enable TE in PE when all tests are fixed
torch._C._jit_set_texpr_fuser_enabled(GRAPH_EXECUTOR == ProfilingMode.PROFILING)
torch._C._jit_set_profiling_executor(GRAPH_EXECUTOR != ProfilingMode.LEGACY)
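
To make the recurring "__main__" pattern above concrete, here is a minimal, hypothetical test file (class and test names are invented): unittest.skipIf evaluates its condition while the class body executes, and a from-import snapshots the module global, so both have to come after parse_cmd_line_args() has filled in GRAPH_EXECUTOR.

import unittest

if __name__ == "__main__":
    from torch.testing._internal.common_utils import parse_cmd_line_args
    # Fill in GRAPH_EXECUTOR from --jit-executor (or the C++ defaults) before
    # the import below snapshots it and before the class body runs skipIf().
    parse_cmd_line_args()

from torch.testing._internal.common_utils import (
    GRAPH_EXECUTOR,
    ProfilingMode,
    run_tests,
)


class TestExecutorGated(unittest.TestCase):
    @unittest.skipIf(
        GRAPH_EXECUTOR == ProfilingMode.SIMPLE,
        "simple executor doesn't support gradients",
    )
    def test_needs_gradients(self):
        self.assertTrue(True)


if __name__ == "__main__":
    run_tests()
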

View File

@ -5,12 +5,17 @@ from torch.cuda.amp import autocast
from typing import Optional
import unittest
from test_jit import JitTestCase
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests, skipIfTorchDynamo
from torch.testing import FileCheck
from jit.test_models import MnistNet
if __name__ == '__main__':
# The value of GRAPH_EXECUTOR depends on command-line arguments, so make sure they're parsed
# before instantiating tests.
parse_cmd_line_args()
from test_jit import JitTestCase
TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()
@skipIfTorchDynamo("Not a TorchDynamo suitable test")

View File

@ -9,6 +9,13 @@ import torch.nn.functional as F
from torch.testing import FileCheck
from unittest import skipIf
if __name__ == "__main__":
from torch.testing._internal.common_utils import parse_cmd_line_args
# The value of GRAPH_EXECUTOR depends on command-line arguments, so make sure they're parsed
# before instantiating tests.
parse_cmd_line_args()
from torch.testing._internal.common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \
enable_profiling_mode_for_profiling_tests, IS_WINDOWS, TemporaryDirectoryName, shell
from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, _inline_everything, \

View File

@ -2,6 +2,14 @@
import sys
sys.argv.append("--jit-executor=legacy")
if __name__ == "__main__":
from torch.testing._internal.common_utils import parse_cmd_line_args
# The value of GRAPH_EXECUTOR depends on command-line arguments, so make sure they're parsed
# before instantiating tests.
parse_cmd_line_args()
from test_jit_fuser import * # noqa: F403
if __name__ == '__main__':

View File

@ -22,6 +22,13 @@ from torch.testing import FileCheck
torch._C._jit_set_profiling_executor(True)
torch._C._get_graph_executor_optimize(True)
if __name__ == "__main__":
from torch.testing._internal.common_utils import parse_cmd_line_args
# The value of GRAPH_EXECUTOR depends on command-line arguments, so make sure they're parsed
# before instantiating tests.
parse_cmd_line_args()
from itertools import combinations, permutations, product
from textwrap import dedent

View File

@ -2,7 +2,14 @@
import sys
sys.argv.append("--jit-executor=legacy")
from test_jit import * # noqa: F403
from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests
if __name__ == '__main__':
# The value of GRAPH_EXECUTOR depends on command-line arguments, so make sure they're parsed
# before instantiating tests.
parse_cmd_line_args()
from test_jit import * # noqa: F403, F401
if __name__ == '__main__':
run_tests()

View File

@ -33,6 +33,7 @@ import torch.nn as nn
from torch._C._autograd import DeviceType
from torch._C._distributed_c10d import _SymmetricMemory
from torch._logging._internal import trace_log
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import (
FILE_SCHEMA,
find_free_port,
@ -772,7 +773,12 @@ class MultiProcessTestCase(TestCase):
process = proc(
target=self.__class__._run,
name="process " + str(rank),
args=(rank, self._current_test_name(), self.file_name, child_conn),
args=(
rank,
self._current_test_name(),
self.file_name,
child_conn,
),
kwargs={
"fake_pg": getattr(self, "fake_pg", False),
},
@ -849,6 +855,7 @@ class MultiProcessTestCase(TestCase):
torch._C._set_print_stack_traces_on_fatal_signal(True)
# Show full C++ stacktraces when a Python error originating from C++ is raised.
os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"
common_utils.set_rng_seed()
# self.id() == e.g. '__main__.TestDistributed.test_get_rank'
# We're retrieving a corresponding test and executing it.
@ -1670,6 +1677,10 @@ class MultiProcContinuousTest(TestCase):
self.rank = cls.rank
self.world_size = cls.world_size
test_fn = getattr(self, test_name)
# Ensure all the ranks use the same seed.
common_utils.set_rng_seed()
# Run the test function
test_fn(**kwargs)
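
Both hunks above apply the same idea; the following standalone sketch (the process function and queue are invented, and it uses torch.multiprocessing directly rather than the test framework) shows why each spawned rank has to be reseeded: child processes do not start from a synchronized RNG state, so any randn() in shared setup code would otherwise produce different values per rank.

import torch
import torch.multiprocessing as mp

SEED = 1234  # same default that common_utils.set_rng_seed() falls back to


def _rank_main(rank, out_queue):
    torch.manual_seed(SEED)  # the torch part of set_rng_seed()
    # After seeding, RNG-dependent test data is identical on every rank.
    out_queue.put((rank, torch.randn(3)))


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    queue = ctx.SimpleQueue()
    procs = [ctx.Process(target=_rank_main, args=(r, queue)) for r in range(2)]
    for p in procs:
        p.start()
    results = dict(queue.get() for _ in procs)
    for p in procs:
        p.join()
    assert torch.equal(results[0], results[1])
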

View File

@ -58,6 +58,7 @@ from torch.testing._internal.common_distributed import (
from torch.testing._internal.common_utils import (
FILE_SCHEMA,
get_cycles_per_ms,
set_rng_seed,
TEST_CUDA,
TEST_HPU,
TEST_XPU,
@ -1228,6 +1229,7 @@ class FSDPTest(MultiProcessTestCase):
dist.barrier(device_ids=device_ids)
torch._dynamo.reset()
set_rng_seed()
self.run_test(test_name, pipe)
torch._dynamo.reset()

View File

@ -15,6 +15,7 @@ import torch.cuda
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import _reduction as _Reduction
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
gradcheck, gradgradcheck, set_default_dtype, skipIfTorchDynamo, TEST_WITH_ROCM
from torch.testing._internal.common_cuda import TEST_CUDA, SM90OrLater
@ -1078,6 +1079,7 @@ def single_batch_reference_fn(input, parameters, module):
def get_new_module_tests():
common_utils.set_rng_seed()
new_module_tests = [
poissonnllloss_no_reduce_test(),
bceloss_no_reduce_test(),

View File

@ -101,9 +101,35 @@ except ImportError:
has_pytest = False
SEED = 1234
MI300_ARCH = ("gfx942",)
MI200_ARCH = ("gfx90a")
class ProfilingMode(Enum):
LEGACY = 1
SIMPLE = 2
PROFILING = 3
# Set by parse_cmd_line_args() if called
CI_FUNCTORCH_ROOT = ""
CI_PT_ROOT = ""
CI_TEST_PREFIX = ""
DISABLED_TESTS_FILE = ""
GRAPH_EXECUTOR : Optional[ProfilingMode] = None
LOG_SUFFIX = ""
PYTEST_SINGLE_TEST = ""
REPEAT_COUNT = 0
RERUN_DISABLED_TESTS = False
RUN_PARALLEL = 0
SHOWLOCALS = False
SLOW_TESTS_FILE = ""
TEST_BAILOUTS = False
TEST_DISCOVER = False
TEST_IN_SUBPROCESS = False
TEST_SAVE_XML = ""
UNITTEST_ARGS : list[str] = []
USE_PYTEST = False
def freeze_rng_state(*args, **kwargs):
return torch.testing._utils.freeze_rng_state(*args, **kwargs)
@ -838,11 +864,6 @@ class decorateIf(_TestParametrizer):
yield (test_wrapper, test_name, {}, decorator_fn)
class ProfilingMode(Enum):
LEGACY = 1
SIMPLE = 2
PROFILING = 3
def cppProfilingFlagsToProfilingMode():
old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
@ -861,6 +882,7 @@ def cppProfilingFlagsToProfilingMode():
def enable_profiling_mode_for_profiling_tests():
old_prof_exec_state = False
old_prof_mode_state = False
assert GRAPH_EXECUTOR
if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
@ -895,6 +917,7 @@ meth_call = torch._C.ScriptMethod.__call__
def prof_callable(callable, *args, **kwargs):
if 'profile_and_replay' in kwargs:
del kwargs['profile_and_replay']
assert GRAPH_EXECUTOR
if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
with enable_profiling_mode_for_profiling_tests():
callable(*args, **kwargs)
@ -924,72 +947,91 @@ def _get_test_report_path():
test_source = override if override is not None else 'python-unittest'
return os.path.join('test-reports', test_source)
is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "")
parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False)
parser.add_argument('--subprocess', action='store_true',
help='whether to run each test in a subprocess')
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--accept', action='store_true')
parser.add_argument('--jit-executor', '--jit_executor', type=str)
parser.add_argument('--repeat', type=int, default=1)
parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
parser.add_argument('--use-pytest', action='store_true')
parser.add_argument('--save-xml', nargs='?', type=str,
const=_get_test_report_path(),
default=_get_test_report_path() if IS_CI else None)
parser.add_argument('--discover-tests', action='store_true')
parser.add_argument('--log-suffix', type=str, default="")
parser.add_argument('--run-parallel', type=int, default=1)
parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE)
parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE)
parser.add_argument('--rerun-disabled-tests', action='store_true')
parser.add_argument('--pytest-single-test', type=str, nargs=1)
parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False)
def parse_cmd_line_args():
global CI_FUNCTORCH_ROOT
global CI_PT_ROOT
global CI_TEST_PREFIX
global DISABLED_TESTS_FILE
global GRAPH_EXECUTOR
global LOG_SUFFIX
global PYTEST_SINGLE_TEST
global REPEAT_COUNT
global RERUN_DISABLED_TESTS
global RUN_PARALLEL
global SHOWLOCALS
global SLOW_TESTS_FILE
global TEST_BAILOUTS
global TEST_DISCOVER
global TEST_IN_SUBPROCESS
global TEST_SAVE_XML
global UNITTEST_ARGS
global USE_PYTEST
is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "")
parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False)
parser.add_argument('--subprocess', action='store_true',
help='whether to run each test in a subprocess')
parser.add_argument('--accept', action='store_true')
parser.add_argument('--jit-executor', '--jit_executor', type=str)
parser.add_argument('--repeat', type=int, default=1)
parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
parser.add_argument('--use-pytest', action='store_true')
parser.add_argument('--save-xml', nargs='?', type=str,
const=_get_test_report_path(),
default=_get_test_report_path() if IS_CI else None)
parser.add_argument('--discover-tests', action='store_true')
parser.add_argument('--log-suffix', type=str, default="")
parser.add_argument('--run-parallel', type=int, default=1)
parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE)
parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE)
parser.add_argument('--rerun-disabled-tests', action='store_true')
parser.add_argument('--pytest-single-test', type=str, nargs=1)
parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False)
# Only run when -h or --help flag is active to display both unittest and parser help messages.
def run_unittest_help(argv):
unittest.main(argv=argv)
def run_unittest_help(argv):
unittest.main(argv=argv)
if '-h' in sys.argv or '--help' in sys.argv:
help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,))
help_thread.start()
help_thread.join()
if '-h' in sys.argv or '--help' in sys.argv:
help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,))
help_thread.start()
help_thread.join()
args, remaining = parser.parse_known_args()
if args.jit_executor == 'legacy':
GRAPH_EXECUTOR = ProfilingMode.LEGACY
elif args.jit_executor == 'profiling':
GRAPH_EXECUTOR = ProfilingMode.PROFILING
elif args.jit_executor == 'simple':
GRAPH_EXECUTOR = ProfilingMode.SIMPLE
else:
# infer flags based on the default settings
GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode()
args, remaining = parser.parse_known_args()
if args.jit_executor == 'legacy':
GRAPH_EXECUTOR = ProfilingMode.LEGACY
elif args.jit_executor == 'profiling':
GRAPH_EXECUTOR = ProfilingMode.PROFILING
elif args.jit_executor == 'simple':
GRAPH_EXECUTOR = ProfilingMode.SIMPLE
else:
# infer flags based on the default settings
GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode()
RERUN_DISABLED_TESTS = args.rerun_disabled_tests
RERUN_DISABLED_TESTS = args.rerun_disabled_tests
SLOW_TESTS_FILE = args.import_slow_tests
DISABLED_TESTS_FILE = args.import_disabled_tests
LOG_SUFFIX = args.log_suffix
RUN_PARALLEL = args.run_parallel
TEST_BAILOUTS = args.test_bailouts
USE_PYTEST = args.use_pytest
PYTEST_SINGLE_TEST = args.pytest_single_test
TEST_DISCOVER = args.discover_tests
TEST_IN_SUBPROCESS = args.subprocess
TEST_SAVE_XML = args.save_xml
REPEAT_COUNT = args.repeat
SEED = args.seed
SHOWLOCALS = args.showlocals
if not getattr(expecttest, "ACCEPT", False):
expecttest.ACCEPT = args.accept
UNITTEST_ARGS = [sys.argv[0]] + remaining
torch.manual_seed(SEED)
SLOW_TESTS_FILE = args.import_slow_tests
DISABLED_TESTS_FILE = args.import_disabled_tests
LOG_SUFFIX = args.log_suffix
RUN_PARALLEL = args.run_parallel
TEST_BAILOUTS = args.test_bailouts
USE_PYTEST = args.use_pytest
PYTEST_SINGLE_TEST = args.pytest_single_test
TEST_DISCOVER = args.discover_tests
TEST_IN_SUBPROCESS = args.subprocess
TEST_SAVE_XML = args.save_xml
REPEAT_COUNT = args.repeat
SHOWLOCALS = args.showlocals
if not getattr(expecttest, "ACCEPT", False):
expecttest.ACCEPT = args.accept
UNITTEST_ARGS = [sys.argv[0]] + remaining
set_rng_seed()
# CI Prefix path used only on CI environment
CI_TEST_PREFIX = str(Path(os.getcwd()))
CI_PT_ROOT = str(Path(os.getcwd()).parent)
CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))
CI_TEST_PREFIX = str(Path(os.getcwd()))
CI_PT_ROOT = str(Path(os.getcwd()).parent)
CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))
def wait_for_process(p, timeout=None):
try:
@ -1138,7 +1180,9 @@ def lint_test_case_extension(suite):
return succeed
def get_report_path(argv=UNITTEST_ARGS, pytest=False):
def get_report_path(argv=None, pytest=False):
if argv is None:
argv = UNITTEST_ARGS
test_filename = sanitize_test_filename(argv[0])
test_report_path = TEST_SAVE_XML + LOG_SUFFIX
test_report_path = os.path.join(test_report_path, test_filename)
@ -1189,7 +1233,11 @@ def get_pytest_test_cases(argv: list[str]) -> list[str]:
return test_collector_plugin.tests
def run_tests(argv=UNITTEST_ARGS):
def run_tests(argv=None):
parse_cmd_line_args()
if argv is None:
argv = UNITTEST_ARGS
# import test files.
if SLOW_TESTS_FILE:
if os.path.exists(SLOW_TESTS_FILE):
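
A brief aside on the argv=None changes in the two hunks above, with a toy, self-contained demonstration (all names invented): a default such as argv=UNITTEST_ARGS is evaluated once, when the def statement runs, so it would capture the empty placeholder list from import time rather than the value parse_cmd_line_args() later rebinds; resolving None at call time sidesteps that, and having run_tests() call parse_cmd_line_args() itself keeps test files that only call run_tests() working unchanged.

ARGS: list[str] = []  # stands in for the UNITTEST_ARGS placeholder


def early_bound(argv=ARGS):
    # Default was captured at definition time, while ARGS was still [].
    return argv


def late_bound(argv=None):
    # Default is resolved at call time, after "parsing" rebinds ARGS.
    return ARGS if argv is None else argv


ARGS = ["test_example.py", "-v"]  # what parsing fills in later
print(early_bound())  # [] -- stale snapshot of the old list object
print(late_bound())   # ['test_example.py', '-v']
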
@ -1759,6 +1807,7 @@ def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT exe
if not isinstance(fn, type):
@wraps(fn)
def wrapper(*args, **kwargs):
assert GRAPH_EXECUTOR
if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
raise unittest.SkipTest(msg)
else:
@ -2379,7 +2428,9 @@ def get_function_arglist(func):
return inspect.getfullargspec(func).args
def set_rng_seed(seed):
def set_rng_seed(seed=None):
if seed is None:
seed = SEED
torch.manual_seed(seed)
random.seed(seed)
if TEST_NUMPY:
@ -3402,7 +3453,7 @@ class TestCase(expecttest.TestCase):
def setUp(self):
check_if_enable(self)
set_rng_seed(SEED)
set_rng_seed()
# Save global check sparse tensor invariants state that can be
# restored from tearDown:

View File

@ -137,16 +137,18 @@ class Foo:
f = Foo(10)
f.bar = 1
foo_cpu_tensor = Foo(torch.randn(3, 3))
# Defer instantiation until the seed is set so that randn() returns the same
# values in all processes.
def create_collectives_object_test_list():
return [
{"key1": 3, "key2": 4, "key3": {"nested": True}},
f,
Foo(torch.randn(3, 3)),
"foo",
[1, 2, True, "string", [4, 5, "nested"]],
]
COLLECTIVES_OBJECT_TEST_LIST = [
{"key1": 3, "key2": 4, "key3": {"nested": True}},
f,
foo_cpu_tensor,
"foo",
[1, 2, True, "string", [4, 5, "nested"]],
]
# Allowlist of distributed backends where profiling collectives is supported.
PROFILING_SUPPORTED_BACKENDS = [
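
The factory above defers the randn() call in the same spirit; a hypothetical, self-contained sketch (the helper name is invented) of the difference between drawing tensors at import time and drawing them only after the seed has been set:

import torch

EAGER = torch.randn(3, 3)  # drawn at import time, before any seeding


def make_test_objects():
    # Deferred: called only after the test harness has seeded the RNG,
    # so every rank draws identical values.
    return [{"key": 3}, torch.randn(3, 3), "foo"]


torch.manual_seed(1234)   # stands in for common_utils.set_rng_seed()
rank_a = make_test_objects()
torch.manual_seed(1234)   # another rank seeding with the same value
rank_b = make_test_objects()
assert torch.equal(rank_a[1], rank_b[1])
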
@ -396,12 +398,6 @@ class ControlFlowToyModel(nn.Module):
return F.relu(self.lin1(x))
DDP_NET = Net()
BN_NET = BatchNormNet()
BN_NET_NO_AFFINE = BatchNormNet(affine=False)
ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99)
def get_timeout(test_id):
test_name = test_id.split(".")[-1]
if test_name in CUSTOMIZED_TIMEOUT:
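
Dropping the module-level model singletons (DDP_NET, BN_NET, BN_NET_NO_AFFINE, ONLY_SBN_NET) follows the same reasoning for nn.Module weights; a brief sketch with a stand-in model (not one of the diff's classes): constructing the model inside the test, after seeding, makes its randomly initialized parameters identical on every rank.

import torch
import torch.nn as nn


def fresh_model():
    # Built per test, after common_utils.set_rng_seed() has run, so the
    # randomly initialized weights match on every rank.
    return nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 4))


torch.manual_seed(1234)
rank_a = fresh_model()
torch.manual_seed(1234)
rank_b = fresh_model()
for p, q in zip(rank_a.state_dict().values(), rank_b.state_dict().values()):
    assert torch.equal(p, q)
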
@ -4293,7 +4289,7 @@ class DistributedTest:
# as baseline
# cpu training setup
model = DDP_NET
model = Net()
# single gpu training setup
model_gpu = copy.deepcopy(model)
@ -4348,7 +4344,7 @@ class DistributedTest:
_group, _group_id, rank = self._init_global_test()
# cpu training setup
model_base = DDP_NET
model_base = Net()
# DDP-CPU training setup
model_DDP = copy.deepcopy(model_base)
@ -5497,7 +5493,7 @@ class DistributedTest:
def _test_DistributedDataParallel_with_amp(self, grad_is_view=False):
torch.manual_seed(31415)
# Creates model and optimizer in default precision
model = copy.deepcopy(DDP_NET).cuda()
model = Net().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.03)
# Creates a GradScaler once at the beginning of training.
@ -5582,7 +5578,7 @@ class DistributedTest:
# as baseline
# cpu training setup
model = BN_NET if affine else BN_NET_NO_AFFINE
model = BatchNormNet() if affine else BatchNormNet(affine=False)
# single gpu training setup
model_gpu = copy.deepcopy(model)
@ -5632,6 +5628,7 @@ class DistributedTest:
def _test_post_localSGD_optimizer_parity(self, create_averager, grad_is_view):
learning_rate = 0.03
DDP_NET = Net()
net = torch.nn.parallel.DistributedDataParallel(
copy.deepcopy(DDP_NET).cuda(),
device_ids=[self.rank],
@ -5698,7 +5695,7 @@ class DistributedTest:
learning_rate = 0.03
net_using_post_localSGD_opt = torch.nn.parallel.DistributedDataParallel(
copy.deepcopy(DDP_NET).cuda(), device_ids=[self.rank]
Net().cuda(), device_ids=[self.rank]
)
averager = create_averager()
@ -5848,7 +5845,7 @@ class DistributedTest:
bs_offset = int(rank * 2)
global_bs = int(num_processes * 2)
model = ONLY_SBN_NET
model = nn.SyncBatchNorm(2, momentum=0.99)
model_gpu = copy.deepcopy(model).cuda(rank)
model_DDP = nn.parallel.DistributedDataParallel(
model_gpu, device_ids=[rank]
@ -6058,6 +6055,7 @@ class DistributedTest:
def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_Running_Value(
self,
):
ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99)
_group, _group_id, rank = self._init_global_test()
model = nn.parallel.DistributedDataParallel(
ONLY_SBN_NET.cuda(rank), device_ids=[rank]
@ -6125,7 +6123,7 @@ class DistributedTest:
def test_DistributedDataParallel_SyncBatchNorm_half(self):
_group, _group_id, rank = self._init_global_test()
model = copy.deepcopy(BN_NET)
model = BatchNormNet()
model = model.half()
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = nn.parallel.DistributedDataParallel(
@ -6141,7 +6139,7 @@ class DistributedTest:
def _test_ddp_logging_data(self, is_gpu):
rank = dist.get_rank()
model_DDP = copy.deepcopy(DDP_NET)
model_DDP = Net()
if is_gpu:
model_DDP = nn.parallel.DistributedDataParallel(
model_DDP.cuda(rank), device_ids=[rank]
@ -6417,7 +6415,7 @@ class DistributedTest:
BACKEND == "nccl", "nccl does not support DDP on CPU models"
)
def test_static_graph_api_cpu(self):
model_DDP = nn.parallel.DistributedDataParallel(DDP_NET)
model_DDP = nn.parallel.DistributedDataParallel(Net())
expected_err = "should be called before training loop starts"
with self.assertRaisesRegex(RuntimeError, expected_err):
local_bs = 2
@ -6650,7 +6648,7 @@ class DistributedTest:
def _test_allgather_object(self, subgroup=None):
# Only set device for NCCL backend since it must use GPUs.
gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
gather_objects = create_collectives_object_test_list()
backend = os.environ["BACKEND"]
if backend == "nccl":
@ -6694,7 +6692,7 @@ class DistributedTest:
def _test_gather_object(self, pg=None):
# Ensure stateful objects can be gathered
gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
gather_objects = create_collectives_object_test_list()
my_rank = dist.get_rank(pg)
backend = os.environ["BACKEND"]
@ -7264,7 +7262,7 @@ class DistributedTest:
return x
torch.cuda.set_device(self.rank)
model_bn = BN_NET
model_bn = BatchNormNet()
model_bn = nn.SyncBatchNorm.convert_sync_batchnorm(
copy.deepcopy(model_bn)
).cuda(self.rank)
@ -7560,7 +7558,7 @@ class DistributedTest:
loss.backward()
def _test_broadcast_object_list(self, group=None):
gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
gather_objects = create_collectives_object_test_list()
# Only set device for NCCL backend since it must use GPUs.
# Case where rank != GPU device.
@ -8284,10 +8282,11 @@ class DistributedTest:
@require_backend_is_available({"gloo"})
def test_scatter_object_list(self):
src_rank = 0
collectives_object_test_list = create_collectives_object_test_list()
scatter_list = (
COLLECTIVES_OBJECT_TEST_LIST
collectives_object_test_list
if self.rank == src_rank
else [None for _ in COLLECTIVES_OBJECT_TEST_LIST]
else [None for _ in collectives_object_test_list]
)
world_size = dist.get_world_size()
scatter_list = scatter_list[:world_size]
@ -8300,8 +8299,8 @@ class DistributedTest:
dist.scatter_object_list(output_obj_list, scatter_list, src=src_rank)
self.assertEqual(
output_obj_list[0],
COLLECTIVES_OBJECT_TEST_LIST[
self.rank % len(COLLECTIVES_OBJECT_TEST_LIST)
collectives_object_test_list[
self.rank % len(collectives_object_test_list)
],
)
# Ensure errors are raised upon incorrect arguments.
@ -9987,7 +9986,7 @@ class DistributedTest:
"Only Nccl & Gloo backend support DistributedDataParallel",
)
def test_sync_bn_logged(self):
model = BN_NET
model = BatchNormNet()
rank = self.rank
# single gpu training setup
model_gpu = model.cuda(rank)