diff --git a/test/conftest.py b/test/conftest.py
index 078e4b3b2b8e..d742430f886d 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -21,16 +21,6 @@ from _pytest.terminal import _get_raw_skip_reason
 
 from pytest_shard_custom import pytest_addoptions as shard_addoptions, PytestShardPlugin
 
-try:
-    from torch.testing._internal.common_utils import parse_cmd_line_args
-except ImportError:
-    # Temporary workaround needed until parse_cmd_line_args makes it into a nightlye because
-    # main / PR's tests are sometimes run against the previous day's nightly which won't
-    # have this function.
-    def parse_cmd_line_args():
-        pass
-
-
 if TYPE_CHECKING:
     from _pytest._code.code import ReprFileLocation
 
@@ -93,7 +83,6 @@ def pytest_addoption(parser: Parser) -> None:
 
 
 def pytest_configure(config: Config) -> None:
-    parse_cmd_line_args()
     xmlpath = config.option.xmlpath_reruns
     # Prevent opening xmllog on worker nodes (xdist).
     if xmlpath and not hasattr(config, "workerinput"):
diff --git a/test/jit/test_autodiff_subgraph_slicing.py b/test/jit/test_autodiff_subgraph_slicing.py
index f128f9c7eec3..f42aa7f8f436 100644
--- a/test/jit/test_autodiff_subgraph_slicing.py
+++ b/test/jit/test_autodiff_subgraph_slicing.py
@@ -27,9 +27,6 @@ from torch.testing._internal.jit_utils import (
 )
 
 
-assert GRAPH_EXECUTOR is not None
-
-
 @unittest.skipIf(
     GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients"
 )
diff --git a/test/test_jit.py b/test/test_jit.py
index 83407e25d0b5..093753851f54 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -3,13 +3,6 @@
 
 import torch
 
-if __name__ == '__main__':
-    from torch.testing._internal.common_utils import parse_cmd_line_args
-
-    # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
-    # before instantiating tests.
-    parse_cmd_line_args()
-
 # This is how we include tests located in test/jit/...
 # They are included here so that they are invoked when you call `test_jit.py`,
 # do not run these test files directly.
@@ -104,7 +97,7 @@ import torch.nn.functional as F
 from torch.testing._internal import jit_utils
 from torch.testing._internal.common_jit import check_against_reference
 from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, \
-    GRAPH_EXECUTOR, suppress_warnings, IS_SANDCASTLE, ProfilingMode, \
+    suppress_warnings, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, \
     TestCase, freeze_rng_state, slowTest, TemporaryFileName, \
     enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \
     skipIfCrossRef, skipIfTorchDynamo
@@ -165,7 +158,6 @@ def doAutodiffCheck(testname):
     if "test_t_" in testname or testname == "test_t":
         return False
 
-    assert GRAPH_EXECUTOR
     if GRAPH_EXECUTOR == ProfilingMode.SIMPLE:
         return False
 
@@ -209,7 +201,6 @@ def doAutodiffCheck(testname):
 
     return testname not in test_exceptions
 
-assert GRAPH_EXECUTOR
 # TODO: enable TE in PE when all tests are fixed
 torch._C._jit_set_texpr_fuser_enabled(GRAPH_EXECUTOR == ProfilingMode.PROFILING)
 torch._C._jit_set_profiling_executor(GRAPH_EXECUTOR != ProfilingMode.LEGACY)
diff --git a/test/test_jit_autocast.py b/test/test_jit_autocast.py
index dcdf78ff4b89..b3cf4d9bee8f 100644
--- a/test/test_jit_autocast.py
+++ b/test/test_jit_autocast.py
@@ -5,17 +5,12 @@ from torch.cuda.amp import autocast
 from typing import Optional
 import unittest
 
+from test_jit import JitTestCase
 from torch.testing._internal.common_cuda import TEST_CUDA
-from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests, skipIfTorchDynamo
+from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
 from torch.testing import FileCheck
 from jit.test_models import MnistNet
 
-if __name__ == '__main__':
-    # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
-    # before instantiating tests.
-    parse_cmd_line_args()
-
-from test_jit import JitTestCase
 TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()
 
 @skipIfTorchDynamo("Not a TorchDynamo suitable test")
diff --git a/test/test_jit_fuser.py b/test/test_jit_fuser.py
index 5446770695c4..1ac7803a9d46 100644
--- a/test/test_jit_fuser.py
+++ b/test/test_jit_fuser.py
@@ -9,13 +9,6 @@ import torch.nn.functional as F
 from torch.testing import FileCheck
 from unittest import skipIf
 
-if __name__ == "__main__":
-    from torch.testing._internal.common_utils import parse_cmd_line_args
-
-    # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
-    # before instantiating tests.
-    parse_cmd_line_args()
-
 from torch.testing._internal.common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \
     enable_profiling_mode_for_profiling_tests, IS_WINDOWS, TemporaryDirectoryName, shell
 from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, _inline_everything, \
diff --git a/test/test_jit_fuser_legacy.py b/test/test_jit_fuser_legacy.py
index 4100bcc3e182..3bd8c9497ce0 100644
--- a/test/test_jit_fuser_legacy.py
+++ b/test/test_jit_fuser_legacy.py
@@ -2,14 +2,6 @@
 import sys
 sys.argv.append("--jit-executor=legacy")
 
-
-if __name__ == "__main__":
-    from torch.testing._internal.common_utils import parse_cmd_line_args
-
-    # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
-    # before instantiating tests.
-    parse_cmd_line_args()
-
 from test_jit_fuser import *  # noqa: F403
 
 if __name__ == '__main__':
diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py
index 1bda41f7f8f1..c3e26d37da1b 100644
--- a/test/test_jit_fuser_te.py
+++ b/test/test_jit_fuser_te.py
@@ -22,13 +22,6 @@ from torch.testing import FileCheck
 torch._C._jit_set_profiling_executor(True)
 torch._C._get_graph_executor_optimize(True)
 
-if __name__ == "__main__":
-    from torch.testing._internal.common_utils import parse_cmd_line_args
-
-    # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
-    # before instantiating tests.
-    parse_cmd_line_args()
-
 from itertools import combinations, permutations, product
 from textwrap import dedent
 
diff --git a/test/test_jit_legacy.py b/test/test_jit_legacy.py
index 480b57a55bd4..5576f1645349 100644
--- a/test/test_jit_legacy.py
+++ b/test/test_jit_legacy.py
@@ -2,14 +2,7 @@
 import sys
 sys.argv.append("--jit-executor=legacy")
 
-from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests
-
-if __name__ == '__main__':
-    # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
-    # before instantiating tests.
-    parse_cmd_line_args()
-
-from test_jit import *  # noqa: F403, F401
+from test_jit import *  # noqa: F403
 
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py
index 22767567ac1f..b445f4ad8535 100644
--- a/torch/testing/_internal/common_distributed.py
+++ b/torch/testing/_internal/common_distributed.py
@@ -33,7 +33,6 @@ import torch.nn as nn
 from torch._C._autograd import DeviceType
 from torch._C._distributed_c10d import _SymmetricMemory
 from torch._logging._internal import trace_log
-from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
     FILE_SCHEMA,
     find_free_port,
@@ -773,12 +772,7 @@ class MultiProcessTestCase(TestCase):
             process = proc(
                 target=self.__class__._run,
                 name="process " + str(rank),
-                args=(
-                    rank,
-                    self._current_test_name(),
-                    self.file_name,
-                    child_conn,
-                ),
+                args=(rank, self._current_test_name(), self.file_name, child_conn),
                 kwargs={
                     "fake_pg": getattr(self, "fake_pg", False),
                 },
@@ -855,7 +849,6 @@ class MultiProcessTestCase(TestCase):
             torch._C._set_print_stack_traces_on_fatal_signal(True)
         # Show full C++ stacktraces when a Python error originating from C++ is raised.
         os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"
-        common_utils.set_rng_seed()
 
         # self.id() == e.g. '__main__.TestDistributed.test_get_rank'
         # We're retrieving a corresponding test and executing it.
@@ -1677,10 +1670,6 @@ class MultiProcContinuousTest(TestCase):
         self.rank = cls.rank
         self.world_size = cls.world_size
         test_fn = getattr(self, test_name)
-
-        # Ensure all the ranks use the same seed.
-        common_utils.set_rng_seed()
-
         # Run the test function
         test_fn(**kwargs)
 
diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py
index 1d8f74702975..c7274fddd6d3 100644
--- a/torch/testing/_internal/common_fsdp.py
+++ b/torch/testing/_internal/common_fsdp.py
@@ -58,7 +58,6 @@ from torch.testing._internal.common_distributed import (
 from torch.testing._internal.common_utils import (
     FILE_SCHEMA,
     get_cycles_per_ms,
-    set_rng_seed,
     TEST_CUDA,
     TEST_HPU,
     TEST_XPU,
@@ -1229,7 +1228,6 @@ class FSDPTest(MultiProcessTestCase):
            dist.barrier(device_ids=device_ids)
 
         torch._dynamo.reset()
-        set_rng_seed()
         self.run_test(test_name, pipe)
         torch._dynamo.reset()
 
diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py
index 3b8f277ceef8..135cc6a7bd66 100644
--- a/torch/testing/_internal/common_nn.py
+++ b/torch/testing/_internal/common_nn.py
@@ -15,7 +15,6 @@ import torch.cuda
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import _reduction as _Reduction
-from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
     gradcheck, gradgradcheck, set_default_dtype, skipIfTorchDynamo, TEST_WITH_ROCM
 from torch.testing._internal.common_cuda import TEST_CUDA, SM90OrLater
@@ -1079,7 +1078,6 @@ def single_batch_reference_fn(input, parameters, module):
 
 
 def get_new_module_tests():
-    common_utils.set_rng_seed()
     new_module_tests = [
         poissonnllloss_no_reduce_test(),
         bceloss_no_reduce_test(),
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index be1e30d0f18a..93a6352831f8 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -101,35 +101,9 @@ except ImportError:
     has_pytest = False
 
 
-SEED = 1234
 MI300_ARCH = ("gfx942",)
 MI200_ARCH = ("gfx90a")
 
-class ProfilingMode(Enum):
-    LEGACY = 1
-    SIMPLE = 2
-    PROFILING = 3
-
-# Set by parse_cmd_line_args() if called
-CI_FUNCTORCH_ROOT = ""
-CI_PT_ROOT = ""
-CI_TEST_PREFIX = ""
-DISABLED_TESTS_FILE = ""
-GRAPH_EXECUTOR : Optional[ProfilingMode] = None
-LOG_SUFFIX = ""
-PYTEST_SINGLE_TEST = ""
-REPEAT_COUNT = 0
-RERUN_DISABLED_TESTS = False
-RUN_PARALLEL = 0
-SHOWLOCALS = False
-SLOW_TESTS_FILE = ""
-TEST_BAILOUTS = False
-TEST_DISCOVER = False
-TEST_IN_SUBPROCESS = False
-TEST_SAVE_XML = ""
-UNITTEST_ARGS : list[str] = []
-USE_PYTEST = False
-
 def freeze_rng_state(*args, **kwargs):
     return torch.testing._utils.freeze_rng_state(*args, **kwargs)
 
@@ -864,6 +838,11 @@ class decorateIf(_TestParametrizer):
             yield (test_wrapper, test_name, {}, decorator_fn)
 
 
+class ProfilingMode(Enum):
+    LEGACY = 1
+    SIMPLE = 2
+    PROFILING = 3
+
 def cppProfilingFlagsToProfilingMode():
     old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
     old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
@@ -882,7 +861,6 @@
 def enable_profiling_mode_for_profiling_tests():
     old_prof_exec_state = False
     old_prof_mode_state = False
-    assert GRAPH_EXECUTOR
     if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
         old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
         old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
@@ -917,7 +895,6 @@ meth_call = torch._C.ScriptMethod.__call__
 def prof_callable(callable, *args, **kwargs):
     if 'profile_and_replay' in kwargs:
         del kwargs['profile_and_replay']
-        assert GRAPH_EXECUTOR
         if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
             with enable_profiling_mode_for_profiling_tests():
                 callable(*args, **kwargs)
@@ -947,91 +924,72 @@ def _get_test_report_path():
     test_source = override if override is not None else 'python-unittest'
     return os.path.join('test-reports', test_source)
 
-def parse_cmd_line_args():
-    global CI_FUNCTORCH_ROOT
-    global CI_PT_ROOT
-    global CI_TEST_PREFIX
-    global DISABLED_TESTS_FILE
-    global GRAPH_EXECUTOR
-    global LOG_SUFFIX
-    global PYTEST_SINGLE_TEST
-    global REPEAT_COUNT
-    global RERUN_DISABLED_TESTS
-    global RUN_PARALLEL
-    global SHOWLOCALS
-    global SLOW_TESTS_FILE
-    global TEST_BAILOUTS
-    global TEST_DISCOVER
-    global TEST_IN_SUBPROCESS
-    global TEST_SAVE_XML
-    global UNITTEST_ARGS
-    global USE_PYTEST
-
-    is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "")
-    parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False)
-    parser.add_argument('--subprocess', action='store_true',
-                        help='whether to run each test in a subprocess')
-    parser.add_argument('--accept', action='store_true')
-    parser.add_argument('--jit-executor', '--jit_executor', type=str)
-    parser.add_argument('--repeat', type=int, default=1)
-    parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
-    parser.add_argument('--use-pytest', action='store_true')
-    parser.add_argument('--save-xml', nargs='?', type=str,
-                        const=_get_test_report_path(),
-                        default=_get_test_report_path() if IS_CI else None)
-    parser.add_argument('--discover-tests', action='store_true')
-    parser.add_argument('--log-suffix', type=str, default="")
-    parser.add_argument('--run-parallel', type=int, default=1)
-    parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE)
-    parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE)
-    parser.add_argument('--rerun-disabled-tests', action='store_true')
-    parser.add_argument('--pytest-single-test', type=str, nargs=1)
-    parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False)
+is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "")
+parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False)
+parser.add_argument('--subprocess', action='store_true',
+                    help='whether to run each test in a subprocess')
+parser.add_argument('--seed', type=int, default=1234)
+parser.add_argument('--accept', action='store_true')
+parser.add_argument('--jit-executor', '--jit_executor', type=str)
+parser.add_argument('--repeat', type=int, default=1)
+parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
+parser.add_argument('--use-pytest', action='store_true')
+parser.add_argument('--save-xml', nargs='?', type=str,
+                    const=_get_test_report_path(),
+                    default=_get_test_report_path() if IS_CI else None)
+parser.add_argument('--discover-tests', action='store_true')
+parser.add_argument('--log-suffix', type=str, default="")
+parser.add_argument('--run-parallel', type=int, default=1)
+parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE)
+parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE)
+parser.add_argument('--rerun-disabled-tests', action='store_true')
+parser.add_argument('--pytest-single-test', type=str, nargs=1)
+parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False)
 
 # Only run when -h or --help flag is active to display both unittest and parser help messages.
-    def run_unittest_help(argv):
-        unittest.main(argv=argv)
+def run_unittest_help(argv):
+    unittest.main(argv=argv)
 
-    if '-h' in sys.argv or '--help' in sys.argv:
-        help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,))
-        help_thread.start()
-        help_thread.join()
+if '-h' in sys.argv or '--help' in sys.argv:
+    help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,))
+    help_thread.start()
+    help_thread.join()
 
-    args, remaining = parser.parse_known_args()
-    if args.jit_executor == 'legacy':
-        GRAPH_EXECUTOR = ProfilingMode.LEGACY
-    elif args.jit_executor == 'profiling':
-        GRAPH_EXECUTOR = ProfilingMode.PROFILING
-    elif args.jit_executor == 'simple':
-        GRAPH_EXECUTOR = ProfilingMode.SIMPLE
-    else:
-        # infer flags based on the default settings
-        GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode()
+args, remaining = parser.parse_known_args()
+if args.jit_executor == 'legacy':
+    GRAPH_EXECUTOR = ProfilingMode.LEGACY
+elif args.jit_executor == 'profiling':
+    GRAPH_EXECUTOR = ProfilingMode.PROFILING
+elif args.jit_executor == 'simple':
+    GRAPH_EXECUTOR = ProfilingMode.SIMPLE
+else:
+    # infer flags based on the default settings
+    GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode()
 
-    RERUN_DISABLED_TESTS = args.rerun_disabled_tests
+RERUN_DISABLED_TESTS = args.rerun_disabled_tests
 
-    SLOW_TESTS_FILE = args.import_slow_tests
-    DISABLED_TESTS_FILE = args.import_disabled_tests
-    LOG_SUFFIX = args.log_suffix
-    RUN_PARALLEL = args.run_parallel
-    TEST_BAILOUTS = args.test_bailouts
-    USE_PYTEST = args.use_pytest
-    PYTEST_SINGLE_TEST = args.pytest_single_test
-    TEST_DISCOVER = args.discover_tests
-    TEST_IN_SUBPROCESS = args.subprocess
-    TEST_SAVE_XML = args.save_xml
-    REPEAT_COUNT = args.repeat
-    SHOWLOCALS = args.showlocals
-    if not getattr(expecttest, "ACCEPT", False):
-        expecttest.ACCEPT = args.accept
-    UNITTEST_ARGS = [sys.argv[0]] + remaining
-
-    set_rng_seed()
+SLOW_TESTS_FILE = args.import_slow_tests
+DISABLED_TESTS_FILE = args.import_disabled_tests
+LOG_SUFFIX = args.log_suffix
+RUN_PARALLEL = args.run_parallel
+TEST_BAILOUTS = args.test_bailouts
+USE_PYTEST = args.use_pytest
+PYTEST_SINGLE_TEST = args.pytest_single_test
+TEST_DISCOVER = args.discover_tests
+TEST_IN_SUBPROCESS = args.subprocess
+TEST_SAVE_XML = args.save_xml
+REPEAT_COUNT = args.repeat
+SEED = args.seed
+SHOWLOCALS = args.showlocals
+if not getattr(expecttest, "ACCEPT", False):
+    expecttest.ACCEPT = args.accept
+UNITTEST_ARGS = [sys.argv[0]] + remaining
+torch.manual_seed(SEED)
 
 # CI Prefix path used only on CI environment
-    CI_TEST_PREFIX = str(Path(os.getcwd()))
-    CI_PT_ROOT = str(Path(os.getcwd()).parent)
-    CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))
+CI_TEST_PREFIX = str(Path(os.getcwd()))
+CI_PT_ROOT = str(Path(os.getcwd()).parent)
+CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))
 
 def wait_for_process(p, timeout=None):
     try:
@@ -1180,9 +1138,7 @@ def lint_test_case_extension(suite):
     return succeed
 
 
-def get_report_path(argv=None, pytest=False):
-    if argv is None:
-        argv = UNITTEST_ARGS
+def get_report_path(argv=UNITTEST_ARGS, pytest=False):
     test_filename = sanitize_test_filename(argv[0])
     test_report_path = TEST_SAVE_XML + LOG_SUFFIX
     test_report_path = os.path.join(test_report_path, test_filename)
@@ -1233,11 +1189,7 @@ def get_pytest_test_cases(argv: list[str]) -> list[str]:
     return test_collector_plugin.tests
 
 
-def run_tests(argv=None):
-    parse_cmd_line_args()
-    if argv is None:
-        argv = UNITTEST_ARGS
-
+def run_tests(argv=UNITTEST_ARGS):
     # import test files.
     if SLOW_TESTS_FILE:
         if os.path.exists(SLOW_TESTS_FILE):
@@ -1807,7 +1759,6 @@ def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT exe
     if not isinstance(fn, type):
         @wraps(fn)
         def wrapper(*args, **kwargs):
-            assert GRAPH_EXECUTOR
             if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
                 raise unittest.SkipTest(msg)
             else:
@@ -2428,9 +2379,7 @@ def get_function_arglist(func):
     return inspect.getfullargspec(func).args
 
 
-def set_rng_seed(seed=None):
-    if seed is None:
-        seed = SEED
+def set_rng_seed(seed):
     torch.manual_seed(seed)
     random.seed(seed)
     if TEST_NUMPY:
@@ -3453,7 +3402,7 @@ class TestCase(expecttest.TestCase):
 
     def setUp(self):
         check_if_enable(self)
-        set_rng_seed()
+        set_rng_seed(SEED)
 
         # Save global check sparse tensor invariants state that can be
         # restored from tearDown:
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index cc47b91db54f..c4701432d81d 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -137,18 +137,16 @@ class Foo:
 
 f = Foo(10)
 f.bar = 1
+foo_cpu_tensor = Foo(torch.randn(3, 3))
 
 
-# Defer instantiation until the seed is set so that randn() returns the same
-# values in all processes.
-def create_collectives_object_test_list():
-    return [
-        {"key1": 3, "key2": 4, "key3": {"nested": True}},
-        f,
-        Foo(torch.randn(3, 3)),
-        "foo",
-        [1, 2, True, "string", [4, 5, "nested"]],
-    ]
+COLLECTIVES_OBJECT_TEST_LIST = [
+    {"key1": 3, "key2": 4, "key3": {"nested": True}},
+    f,
+    foo_cpu_tensor,
+    "foo",
+    [1, 2, True, "string", [4, 5, "nested"]],
+]
 
 # Allowlist of distributed backends where profiling collectives is supported.
 PROFILING_SUPPORTED_BACKENDS = [
@@ -398,6 +396,12 @@ class ControlFlowToyModel(nn.Module):
         return F.relu(self.lin1(x))
 
 
+DDP_NET = Net()
+BN_NET = BatchNormNet()
+BN_NET_NO_AFFINE = BatchNormNet(affine=False)
+ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99)
+
+
 def get_timeout(test_id):
     test_name = test_id.split(".")[-1]
     if test_name in CUSTOMIZED_TIMEOUT:
@@ -4289,7 +4293,7 @@ class DistributedTest:
             # as baseline
 
             # cpu training setup
-            model = Net()
+            model = DDP_NET
 
             # single gpu training setup
             model_gpu = copy.deepcopy(model)
@@ -4344,7 +4348,7 @@ class DistributedTest:
             _group, _group_id, rank = self._init_global_test()
 
             # cpu training setup
-            model_base = Net()
+            model_base = DDP_NET
 
             # DDP-CPU training setup
             model_DDP = copy.deepcopy(model_base)
@@ -5493,7 +5497,7 @@ class DistributedTest:
         def _test_DistributedDataParallel_with_amp(self, grad_is_view=False):
             torch.manual_seed(31415)
             # Creates model and optimizer in default precision
-            model = Net().cuda()
+            model = copy.deepcopy(DDP_NET).cuda()
             optimizer = torch.optim.SGD(model.parameters(), lr=0.03)
 
             # Creates a GradScaler once at the beginning of training.
@@ -5578,7 +5582,7 @@ class DistributedTest:
             # as baseline
 
             # cpu training setup
-            model = BatchNormNet() if affine else BatchNormNet(affine=False)
+            model = BN_NET if affine else BN_NET_NO_AFFINE
 
             # single gpu training setup
             model_gpu = copy.deepcopy(model)
@@ -5628,7 +5632,6 @@ class DistributedTest:
         def _test_post_localSGD_optimizer_parity(self, create_averager, grad_is_view):
             learning_rate = 0.03
 
-            DDP_NET = Net()
             net = torch.nn.parallel.DistributedDataParallel(
                 copy.deepcopy(DDP_NET).cuda(),
                 device_ids=[self.rank],
@@ -5695,7 +5698,7 @@ class DistributedTest:
             learning_rate = 0.03
 
             net_using_post_localSGD_opt = torch.nn.parallel.DistributedDataParallel(
-                Net().cuda(), device_ids=[self.rank]
+                copy.deepcopy(DDP_NET).cuda(), device_ids=[self.rank]
             )
 
             averager = create_averager()
@@ -5845,7 +5848,7 @@ class DistributedTest:
             bs_offset = int(rank * 2)
             global_bs = int(num_processes * 2)
 
-            model = nn.SyncBatchNorm(2, momentum=0.99)
+            model = ONLY_SBN_NET
             model_gpu = copy.deepcopy(model).cuda(rank)
             model_DDP = nn.parallel.DistributedDataParallel(
                 model_gpu, device_ids=[rank]
@@ -6055,7 +6058,6 @@ class DistributedTest:
         def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_Running_Value(
             self,
         ):
-            ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99)
             _group, _group_id, rank = self._init_global_test()
             model = nn.parallel.DistributedDataParallel(
                 ONLY_SBN_NET.cuda(rank), device_ids=[rank]
@@ -6123,7 +6125,7 @@ class DistributedTest:
         def test_DistributedDataParallel_SyncBatchNorm_half(self):
             _group, _group_id, rank = self._init_global_test()
 
-            model = BatchNormNet()
+            model = copy.deepcopy(BN_NET)
             model = model.half()
             model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
             model = nn.parallel.DistributedDataParallel(
@@ -6139,7 +6141,7 @@ class DistributedTest:
 
         def _test_ddp_logging_data(self, is_gpu):
             rank = dist.get_rank()
-            model_DDP = Net()
+            model_DDP = copy.deepcopy(DDP_NET)
             if is_gpu:
                 model_DDP = nn.parallel.DistributedDataParallel(
                     model_DDP.cuda(rank), device_ids=[rank]
@@ -6415,7 +6417,7 @@ class DistributedTest:
             BACKEND == "nccl", "nccl does not support DDP on CPU models"
         )
         def test_static_graph_api_cpu(self):
-            model_DDP = nn.parallel.DistributedDataParallel(Net())
+            model_DDP = nn.parallel.DistributedDataParallel(DDP_NET)
            expected_err = "should be called before training loop starts"
             with self.assertRaisesRegex(RuntimeError, expected_err):
                 local_bs = 2
@@ -6648,7 +6650,7 @@ class DistributedTest:
 
         def _test_allgather_object(self, subgroup=None):
             # Only set device for NCCL backend since it must use GPUs.
-            gather_objects = create_collectives_object_test_list()
+            gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
 
             backend = os.environ["BACKEND"]
             if backend == "nccl":
@@ -6692,7 +6694,7 @@ class DistributedTest:
 
         def _test_gather_object(self, pg=None):
             # Ensure stateful objects can be gathered
-            gather_objects = create_collectives_object_test_list()
+            gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
             my_rank = dist.get_rank(pg)
 
             backend = os.environ["BACKEND"]
@@ -7262,7 +7264,7 @@ class DistributedTest:
                 return x
 
             torch.cuda.set_device(self.rank)
-            model_bn = BatchNormNet()
+            model_bn = BN_NET
             model_bn = nn.SyncBatchNorm.convert_sync_batchnorm(
                 copy.deepcopy(model_bn)
             ).cuda(self.rank)
@@ -7558,7 +7560,7 @@ class DistributedTest:
             loss.backward()
 
         def _test_broadcast_object_list(self, group=None):
-            gather_objects = create_collectives_object_test_list()
+            gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
 
             # Only set device for NCCL backend since it must use GPUs.
             # Case where rank != GPU device.
@@ -8282,11 +8284,10 @@ class DistributedTest:
         @require_backend_is_available({"gloo"})
         def test_scatter_object_list(self):
             src_rank = 0
-            collectives_object_test_list = create_collectives_object_test_list()
             scatter_list = (
-                collectives_object_test_list
+                COLLECTIVES_OBJECT_TEST_LIST
                 if self.rank == src_rank
-                else [None for _ in collectives_object_test_list]
+                else [None for _ in COLLECTIVES_OBJECT_TEST_LIST]
             )
             world_size = dist.get_world_size()
             scatter_list = scatter_list[:world_size]
@@ -8299,8 +8300,8 @@ class DistributedTest:
             dist.scatter_object_list(output_obj_list, scatter_list, src=src_rank)
             self.assertEqual(
                 output_obj_list[0],
-                collectives_object_test_list[
-                    self.rank % len(collectives_object_test_list)
+                COLLECTIVES_OBJECT_TEST_LIST[
+                    self.rank % len(COLLECTIVES_OBJECT_TEST_LIST)
                 ],
             )
             # Ensure errors are raised upon incorrect arguments.
@@ -9986,7 +9987,7 @@ class DistributedTest:
             "Only Nccl & Gloo backend support DistributedDataParallel",
         )
         def test_sync_bn_logged(self):
-            model = BatchNormNet()
+            model = BN_NET
             rank = self.rank
             # single gpu training setup
             model_gpu = model.cuda(rank)