diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index d15cfb1c75b3..fd9e7594828d 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -4280,11 +4280,10 @@ class NCCLTraceTestBase(MultiProcessTestCase): test_name: str, file_name: str, parent_pipe, - seed: int, **kwargs, ) -> None: cls.parent = parent_conn - super()._run(rank, test_name, file_name, parent_pipe, seed) + super()._run(rank, test_name, file_name, parent_pipe) @property def local_device(self): diff --git a/test/jit/test_autodiff_subgraph_slicing.py b/test/jit/test_autodiff_subgraph_slicing.py index f128f9c7eec3..f42aa7f8f436 100644 --- a/test/jit/test_autodiff_subgraph_slicing.py +++ b/test/jit/test_autodiff_subgraph_slicing.py @@ -27,9 +27,6 @@ from torch.testing._internal.jit_utils import ( ) -assert GRAPH_EXECUTOR is not None - - @unittest.skipIf( GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients" ) diff --git a/test/test_cpp_api_parity.py b/test/test_cpp_api_parity.py index 480df4780121..2193243b751e 100644 --- a/test/test_cpp_api_parity.py +++ b/test/test_cpp_api_parity.py @@ -35,11 +35,6 @@ class TestCppApiParity(common.TestCase): functional_test_params_map = {} -if __name__ == "__main__": - # The value of the SEED depends on command line arguments so make sure they're parsed - # before instantiating tests because some modules as part of get_new_module_tests() will call torch.randn - common.parse_cmd_line_args() - expected_test_params_dicts = [] for test_params_dicts, test_instance_class in [ diff --git a/test/test_expanded_weights.py b/test/test_expanded_weights.py index 3696a1c43f43..02bf6d776568 100644 --- a/test/test_expanded_weights.py +++ b/test/test_expanded_weights.py @@ -1008,13 +1008,6 @@ def filter_supported_tests(t): return True -if __name__ == "__main__": - from torch.testing._internal.common_utils import parse_cmd_line_args - - # The value of the SEED depends on command line arguments so make sure they're parsed - # before instantiating tests because some modules as part of get_new_module_tests() will call torch.randn - parse_cmd_line_args() - # TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests # These currently use the legacy nn tests supported_tests = [ diff --git a/test/test_jit.py b/test/test_jit.py index 814b48449b14..c86fb111bfb8 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3,13 +3,6 @@ import torch -if __name__ == '__main__': - from torch.testing._internal.common_utils import parse_cmd_line_args - - # The value of GRAPH_EXECUTOR and SEED depend on command line arguments so make sure they're parsed - # before instantiating tests. - parse_cmd_line_args() - # This is how we include tests located in test/jit/... # They are included here so that they are invoked when you call `test_jit.py`, # do not run these test files directly. @@ -104,7 +97,7 @@ import torch.nn.functional as F from torch.testing._internal import jit_utils from torch.testing._internal.common_jit import check_against_reference from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, \ - GRAPH_EXECUTOR, suppress_warnings, IS_SANDCASTLE, ProfilingMode, \ + suppress_warnings, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, \ TestCase, freeze_rng_state, slowTest, TemporaryFileName, \ enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \ skipIfCrossRef, skipIfTorchDynamo @@ -165,7 +158,6 @@ def doAutodiffCheck(testname): if "test_t_" in testname or testname == "test_t": return False - assert GRAPH_EXECUTOR if GRAPH_EXECUTOR == ProfilingMode.SIMPLE: return False @@ -209,7 +201,6 @@ def doAutodiffCheck(testname): return testname not in test_exceptions -assert GRAPH_EXECUTOR # TODO: enable TE in PE when all tests are fixed torch._C._jit_set_texpr_fuser_enabled(GRAPH_EXECUTOR == ProfilingMode.PROFILING) torch._C._jit_set_profiling_executor(GRAPH_EXECUTOR != ProfilingMode.LEGACY) diff --git a/test/test_jit_autocast.py b/test/test_jit_autocast.py index dcdf78ff4b89..b3cf4d9bee8f 100644 --- a/test/test_jit_autocast.py +++ b/test/test_jit_autocast.py @@ -5,17 +5,12 @@ from torch.cuda.amp import autocast from typing import Optional import unittest +from test_jit import JitTestCase from torch.testing._internal.common_cuda import TEST_CUDA -from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests, skipIfTorchDynamo +from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo from torch.testing import FileCheck from jit.test_models import MnistNet -if __name__ == '__main__': - # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed - # before instantiating tests. - parse_cmd_line_args() - -from test_jit import JitTestCase TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported() @skipIfTorchDynamo("Not a TorchDynamo suitable test") diff --git a/test/test_jit_fuser.py b/test/test_jit_fuser.py index 5446770695c4..1ac7803a9d46 100644 --- a/test/test_jit_fuser.py +++ b/test/test_jit_fuser.py @@ -9,13 +9,6 @@ import torch.nn.functional as F from torch.testing import FileCheck from unittest import skipIf -if __name__ == "__main__": - from torch.testing._internal.common_utils import parse_cmd_line_args - - # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed - # before instantiating tests. - parse_cmd_line_args() - from torch.testing._internal.common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \ enable_profiling_mode_for_profiling_tests, IS_WINDOWS, TemporaryDirectoryName, shell from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, _inline_everything, \ diff --git a/test/test_jit_fuser_legacy.py b/test/test_jit_fuser_legacy.py index 4100bcc3e182..3bd8c9497ce0 100644 --- a/test/test_jit_fuser_legacy.py +++ b/test/test_jit_fuser_legacy.py @@ -2,14 +2,6 @@ import sys sys.argv.append("--jit-executor=legacy") - -if __name__ == "__main__": - from torch.testing._internal.common_utils import parse_cmd_line_args - - # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed - # before instantiating tests. - parse_cmd_line_args() - from test_jit_fuser import * # noqa: F403 if __name__ == '__main__': diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index 16645422e080..8d3a8090c67a 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -22,13 +22,6 @@ from torch.testing import FileCheck torch._C._jit_set_profiling_executor(True) torch._C._get_graph_executor_optimize(True) -if __name__ == "__main__": - from torch.testing._internal.common_utils import parse_cmd_line_args - - # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed - # before instantiating tests. - parse_cmd_line_args() - from itertools import combinations, permutations, product from textwrap import dedent diff --git a/test/test_jit_legacy.py b/test/test_jit_legacy.py index 480b57a55bd4..5576f1645349 100644 --- a/test/test_jit_legacy.py +++ b/test/test_jit_legacy.py @@ -2,14 +2,7 @@ import sys sys.argv.append("--jit-executor=legacy") -from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests - -if __name__ == '__main__': - # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed - # before instantiating tests. - parse_cmd_line_args() - -from test_jit import * # noqa: F403, F401 +from test_jit import * # noqa: F403 if __name__ == '__main__': run_tests() diff --git a/test/test_nn.py b/test/test_nn.py index 2a5b4f6421c9..a09404c40a1e 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -7643,13 +7643,6 @@ def add_test(test, decorator=None): else: add(cuda_test_name, with_tf32_off) -if __name__ == '__main__': - from torch.testing._internal.common_utils import parse_cmd_line_args - - # The value of the SEED depends on command line arguments so make sure they're parsed - # before instantiating tests because some modules as part of get_new_module_tests() will call torch.randn - parse_cmd_line_args() - for test_params in module_tests + get_new_module_tests(): # TODO: CUDA is not implemented yet if 'constructor' not in test_params: diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 9e6486a671ad..af1aafd3871a 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -32,7 +32,6 @@ import torch.nn as nn from torch._C._autograd import DeviceType from torch._C._distributed_c10d import _SymmetricMemory from torch._logging._internal import trace_log -from torch.testing._internal import common_utils from torch.testing._internal.common_utils import ( FILE_SCHEMA, find_free_port, @@ -672,7 +671,6 @@ class MultiProcessTestCase(TestCase): if methodName != "runTest": method_name = methodName super().__init__(method_name) - self.seed = None try: fn = getattr(self, method_name) setattr(self, method_name, self.join_or_run(fn)) @@ -717,20 +715,13 @@ class MultiProcessTestCase(TestCase): def _start_processes(self, proc) -> None: self.processes = [] - assert common_utils.SEED is not None for rank in range(int(self.world_size)): parent_conn, child_conn = torch.multiprocessing.Pipe() process = proc( target=self.__class__._run, name="process " + str(rank), - args=( - rank, - self._current_test_name(), - self.file_name, - child_conn, - ), + args=(rank, self._current_test_name(), self.file_name, child_conn), kwargs={ - "seed": common_utils.SEED, "fake_pg": getattr(self, "fake_pg", False), }, ) @@ -784,12 +775,11 @@ class MultiProcessTestCase(TestCase): @classmethod def _run( - cls, rank: int, test_name: str, file_name: str, parent_pipe, seed: int, **kwargs + cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs ) -> None: self = cls(test_name) self.rank = rank self.file_name = file_name - self.seed = seed self.run_test(test_name, parent_pipe) def run_test(self, test_name: str, parent_pipe) -> None: @@ -808,9 +798,6 @@ class MultiProcessTestCase(TestCase): # Show full C++ stacktraces when a Python error originating from C++ is raised. os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1" - if self.seed is not None: - common_utils.set_rng_seed(self.seed) - # self.id() == e.g. '__main__.TestDistributed.test_get_rank' # We're retrieving a corresponding test and executing it. try: @@ -1548,7 +1535,7 @@ class DynamoDistributedMultiProcTestCase(DistributedTestBase): @classmethod def _run( - cls, rank: int, test_name: str, file_name: str, parent_pipe, seed: int, **kwargs + cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs ) -> None: trace_log.addHandler(logging.NullHandler()) @@ -1556,7 +1543,6 @@ class DynamoDistributedMultiProcTestCase(DistributedTestBase): self = cls(test_name) self.rank = rank self.file_name = file_name - self.seed = seed self.run_test(test_name, parent_pipe) diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index 626a9b8494e4..a9e24eb90ef8 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -57,7 +57,6 @@ from torch.testing._internal.common_distributed import ( from torch.testing._internal.common_utils import ( FILE_SCHEMA, get_cycles_per_ms, - set_rng_seed, TEST_CUDA, TEST_HPU, TEST_XPU, @@ -1181,7 +1180,7 @@ class FSDPTest(MultiProcessTestCase): return run_subtests(self, *args, **kwargs) @classmethod - def _run(cls, rank, test_name, file_name, pipe, seed, **kwargs): # type: ignore[override] + def _run(cls, rank, test_name, file_name, pipe, **kwargs): # type: ignore[override] self = cls(test_name) self.rank = rank self.file_name = file_name @@ -1227,7 +1226,6 @@ class FSDPTest(MultiProcessTestCase): dist.barrier(device_ids=device_ids) torch._dynamo.reset() - set_rng_seed(seed) self.run_test(test_name, pipe) torch._dynamo.reset() diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py index b42114d7d0cd..135cc6a7bd66 100644 --- a/torch/testing/_internal/common_nn.py +++ b/torch/testing/_internal/common_nn.py @@ -15,7 +15,6 @@ import torch.cuda import torch.nn as nn import torch.nn.functional as F from torch.nn import _reduction as _Reduction -from torch.testing._internal import common_utils from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \ gradcheck, gradgradcheck, set_default_dtype, skipIfTorchDynamo, TEST_WITH_ROCM from torch.testing._internal.common_cuda import TEST_CUDA, SM90OrLater @@ -1079,7 +1078,6 @@ def single_batch_reference_fn(input, parameters, module): def get_new_module_tests(): - assert common_utils.SEED is not None, "Make sure the seed is set before calling get_new_module_tests()" new_module_tests = [ poissonnllloss_no_reduce_test(), bceloss_no_reduce_test(), diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 32878eaa9b33..384db57e92ec 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -104,31 +104,6 @@ except ImportError: MI300_ARCH = ("gfx942",) -class ProfilingMode(Enum): - LEGACY = 1 - SIMPLE = 2 - PROFILING = 3 - -# Set by parse_cmd_line_args() if called -CI_FUNCTORCH_ROOT = "" -CI_PT_ROOT = "" -CI_TEST_PREFIX = "" -DISABLED_TESTS_FILE = "" -GRAPH_EXECUTOR : Optional[ProfilingMode] = None -LOG_SUFFIX = "" -PYTEST_SINGLE_TEST = "" -REPEAT_COUNT = 0 -RERUN_DISABLED_TESTS = False -RUN_PARALLEL = 0 -SEED : Optional[int] = None -SHOWLOCALS = False -SLOW_TESTS_FILE = "" -TEST_BAILOUTS = False -TEST_DISCOVER = False -TEST_IN_SUBPROCESS = False -TEST_SAVE_XML = "" -UNITTEST_ARGS : list[str] = [] -USE_PYTEST = False def freeze_rng_state(*args, **kwargs): return torch.testing._utils.freeze_rng_state(*args, **kwargs) @@ -863,6 +838,11 @@ class decorateIf(_TestParametrizer): yield (test_wrapper, test_name, {}, decorator_fn) +class ProfilingMode(Enum): + LEGACY = 1 + SIMPLE = 2 + PROFILING = 3 + def cppProfilingFlagsToProfilingMode(): old_prof_exec_state = torch._C._jit_set_profiling_executor(True) old_prof_mode_state = torch._C._get_graph_executor_optimize(True) @@ -881,7 +861,6 @@ def cppProfilingFlagsToProfilingMode(): def enable_profiling_mode_for_profiling_tests(): old_prof_exec_state = False old_prof_mode_state = False - assert GRAPH_EXECUTOR if GRAPH_EXECUTOR == ProfilingMode.PROFILING: old_prof_exec_state = torch._C._jit_set_profiling_executor(True) old_prof_mode_state = torch._C._get_graph_executor_optimize(True) @@ -916,7 +895,6 @@ meth_call = torch._C.ScriptMethod.__call__ def prof_callable(callable, *args, **kwargs): if 'profile_and_replay' in kwargs: del kwargs['profile_and_replay'] - assert GRAPH_EXECUTOR if GRAPH_EXECUTOR == ProfilingMode.PROFILING: with enable_profiling_mode_for_profiling_tests(): callable(*args, **kwargs) @@ -946,93 +924,72 @@ def _get_test_report_path(): test_source = override if override is not None else 'python-unittest' return os.path.join('test-reports', test_source) -def parse_cmd_line_args(): - global CI_FUNCTORCH_ROOT - global CI_PT_ROOT - global CI_TEST_PREFIX - global DISABLED_TESTS_FILE - global GRAPH_EXECUTOR - global LOG_SUFFIX - global PYTEST_SINGLE_TEST - global REPEAT_COUNT - global RERUN_DISABLED_TESTS - global RUN_PARALLEL - global SEED - global SHOWLOCALS - global SLOW_TESTS_FILE - global TEST_BAILOUTS - global TEST_DISCOVER - global TEST_IN_SUBPROCESS - global TEST_SAVE_XML - global UNITTEST_ARGS - global USE_PYTEST - - is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "") - parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False) - parser.add_argument('--subprocess', action='store_true', - help='whether to run each test in a subprocess') - parser.add_argument('--seed', type=int, default=1234) - parser.add_argument('--accept', action='store_true') - parser.add_argument('--jit-executor', '--jit_executor', type=str) - parser.add_argument('--repeat', type=int, default=1) - parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true') - parser.add_argument('--use-pytest', action='store_true') - parser.add_argument('--save-xml', nargs='?', type=str, - const=_get_test_report_path(), - default=_get_test_report_path() if IS_CI else None) - parser.add_argument('--discover-tests', action='store_true') - parser.add_argument('--log-suffix', type=str, default="") - parser.add_argument('--run-parallel', type=int, default=1) - parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE) - parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE) - parser.add_argument('--rerun-disabled-tests', action='store_true') - parser.add_argument('--pytest-single-test', type=str, nargs=1) - parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False) +is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "") +parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False) +parser.add_argument('--subprocess', action='store_true', + help='whether to run each test in a subprocess') +parser.add_argument('--seed', type=int, default=1234) +parser.add_argument('--accept', action='store_true') +parser.add_argument('--jit-executor', '--jit_executor', type=str) +parser.add_argument('--repeat', type=int, default=1) +parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true') +parser.add_argument('--use-pytest', action='store_true') +parser.add_argument('--save-xml', nargs='?', type=str, + const=_get_test_report_path(), + default=_get_test_report_path() if IS_CI else None) +parser.add_argument('--discover-tests', action='store_true') +parser.add_argument('--log-suffix', type=str, default="") +parser.add_argument('--run-parallel', type=int, default=1) +parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE) +parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE) +parser.add_argument('--rerun-disabled-tests', action='store_true') +parser.add_argument('--pytest-single-test', type=str, nargs=1) +parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False) # Only run when -h or --help flag is active to display both unittest and parser help messages. - def run_unittest_help(argv): - unittest.main(argv=argv) +def run_unittest_help(argv): + unittest.main(argv=argv) - if '-h' in sys.argv or '--help' in sys.argv: - help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,)) - help_thread.start() - help_thread.join() +if '-h' in sys.argv or '--help' in sys.argv: + help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,)) + help_thread.start() + help_thread.join() - args, remaining = parser.parse_known_args() - if args.jit_executor == 'legacy': - GRAPH_EXECUTOR = ProfilingMode.LEGACY - elif args.jit_executor == 'profiling': - GRAPH_EXECUTOR = ProfilingMode.PROFILING - elif args.jit_executor == 'simple': - GRAPH_EXECUTOR = ProfilingMode.SIMPLE - else: - # infer flags based on the default settings - GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode() +args, remaining = parser.parse_known_args() +if args.jit_executor == 'legacy': + GRAPH_EXECUTOR = ProfilingMode.LEGACY +elif args.jit_executor == 'profiling': + GRAPH_EXECUTOR = ProfilingMode.PROFILING +elif args.jit_executor == 'simple': + GRAPH_EXECUTOR = ProfilingMode.SIMPLE +else: + # infer flags based on the default settings + GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode() - RERUN_DISABLED_TESTS = args.rerun_disabled_tests +RERUN_DISABLED_TESTS = args.rerun_disabled_tests - SLOW_TESTS_FILE = args.import_slow_tests - DISABLED_TESTS_FILE = args.import_disabled_tests - LOG_SUFFIX = args.log_suffix - RUN_PARALLEL = args.run_parallel - TEST_BAILOUTS = args.test_bailouts - USE_PYTEST = args.use_pytest - PYTEST_SINGLE_TEST = args.pytest_single_test - TEST_DISCOVER = args.discover_tests - TEST_IN_SUBPROCESS = args.subprocess - TEST_SAVE_XML = args.save_xml - REPEAT_COUNT = args.repeat - SEED = args.seed - SHOWLOCALS = args.showlocals - if not getattr(expecttest, "ACCEPT", False): - expecttest.ACCEPT = args.accept - UNITTEST_ARGS = [sys.argv[0]] + remaining - torch.manual_seed(SEED) +SLOW_TESTS_FILE = args.import_slow_tests +DISABLED_TESTS_FILE = args.import_disabled_tests +LOG_SUFFIX = args.log_suffix +RUN_PARALLEL = args.run_parallel +TEST_BAILOUTS = args.test_bailouts +USE_PYTEST = args.use_pytest +PYTEST_SINGLE_TEST = args.pytest_single_test +TEST_DISCOVER = args.discover_tests +TEST_IN_SUBPROCESS = args.subprocess +TEST_SAVE_XML = args.save_xml +REPEAT_COUNT = args.repeat +SEED = args.seed +SHOWLOCALS = args.showlocals +if not getattr(expecttest, "ACCEPT", False): + expecttest.ACCEPT = args.accept +UNITTEST_ARGS = [sys.argv[0]] + remaining +torch.manual_seed(SEED) # CI Prefix path used only on CI environment - CI_TEST_PREFIX = str(Path(os.getcwd())) - CI_PT_ROOT = str(Path(os.getcwd()).parent) - CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch")) +CI_TEST_PREFIX = str(Path(os.getcwd())) +CI_PT_ROOT = str(Path(os.getcwd()).parent) +CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch")) def wait_for_process(p, timeout=None): try: @@ -1181,9 +1138,7 @@ def lint_test_case_extension(suite): return succeed -def get_report_path(argv=None, pytest=False): - if argv is None: - argv = UNITTEST_ARGS +def get_report_path(argv=UNITTEST_ARGS, pytest=False): test_filename = sanitize_test_filename(argv[0]) test_report_path = TEST_SAVE_XML + LOG_SUFFIX test_report_path = os.path.join(test_report_path, test_filename) @@ -1234,11 +1189,7 @@ def get_pytest_test_cases(argv: list[str]) -> list[str]: return test_collector_plugin.tests -def run_tests(argv=None): - parse_cmd_line_args() - if argv is None: - argv = UNITTEST_ARGS - +def run_tests(argv=UNITTEST_ARGS): # import test files. if SLOW_TESTS_FILE: if os.path.exists(SLOW_TESTS_FILE): @@ -1804,7 +1755,6 @@ def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT exe if not isinstance(fn, type): @wraps(fn) def wrapper(*args, **kwargs): - assert GRAPH_EXECUTOR if GRAPH_EXECUTOR == ProfilingMode.LEGACY: raise unittest.SkipTest(msg) else: @@ -2435,19 +2385,7 @@ def get_function_arglist(func): return inspect.getfullargspec(func).args -def set_rng_seed(seed=None): - if seed is None: - if SEED is not None: - seed = SEED - else: - # Can't assert here: this function is called by TestCase.setUp() and some out of tree tests inherit from that class. - # So just print a warning and hardcode the seed. - seed = 1234 - msg = ("set_rng_seed() was called without providing a seed and the command line " - f"arguments haven't been parsed so the seed will be set to {seed}. " - "To remove this warning make sure your test is run via run_tests() or " - "parse_cmd_line_args() is called before set_rng_seed() is called.") - warnings.warn(msg) +def set_rng_seed(seed): torch.manual_seed(seed) random.seed(seed) if TEST_NUMPY: @@ -3470,7 +3408,7 @@ class TestCase(expecttest.TestCase): def setUp(self): check_if_enable(self) - set_rng_seed() + set_rng_seed(SEED) # Save global check sparse tensor invariants state that can be # restored from tearDown: diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 28cd7efc3226..28b761a37d58 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -137,18 +137,16 @@ class Foo: f = Foo(10) f.bar = 1 +foo_cpu_tensor = Foo(torch.randn(3, 3)) -# Defer instantiation until the seed is set so that randn() returns the same -# values in all processes. -def create_collectives_object_test_list(): - return [ - {"key1": 3, "key2": 4, "key3": {"nested": True}}, - f, - Foo(torch.randn(3, 3)), - "foo", - [1, 2, True, "string", [4, 5, "nested"]], - ] +COLLECTIVES_OBJECT_TEST_LIST = [ + {"key1": 3, "key2": 4, "key3": {"nested": True}}, + f, + foo_cpu_tensor, + "foo", + [1, 2, True, "string", [4, 5, "nested"]], +] # Allowlist of distributed backends where profiling collectives is supported. PROFILING_SUPPORTED_BACKENDS = [ @@ -398,6 +396,12 @@ class ControlFlowToyModel(nn.Module): return F.relu(self.lin1(x)) +DDP_NET = Net() +BN_NET = BatchNormNet() +BN_NET_NO_AFFINE = BatchNormNet(affine=False) +ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99) + + def get_timeout(test_id): test_name = test_id.split(".")[-1] if test_name in CUSTOMIZED_TIMEOUT: @@ -590,13 +594,12 @@ class TestDistBackend(MultiProcessTestCase): return False @classmethod - def _run(cls, rank, test_name, file_name, pipe, seed, **kwargs): + def _run(cls, rank, test_name, file_name, pipe, **kwargs): if BACKEND == "nccl" and not torch.cuda.is_available(): sys.exit(TEST_SKIPS["no_cuda"].exit_code) self = cls(test_name) self.rank = rank self.file_name = file_name - self.seed = seed if torch.cuda.is_available() and torch.cuda.device_count() < int( self.world_size @@ -4284,7 +4287,7 @@ class DistributedTest: # as baseline # cpu training setup - model = Net() + model = DDP_NET # single gpu training setup model_gpu = copy.deepcopy(model) @@ -4339,7 +4342,7 @@ class DistributedTest: _group, _group_id, rank = self._init_global_test() # cpu training setup - model_base = Net() + model_base = DDP_NET # DDP-CPU training setup model_DDP = copy.deepcopy(model_base) @@ -5488,7 +5491,7 @@ class DistributedTest: def _test_DistributedDataParallel_with_amp(self, grad_is_view=False): torch.manual_seed(31415) # Creates model and optimizer in default precision - model = Net().cuda() + model = copy.deepcopy(DDP_NET).cuda() optimizer = torch.optim.SGD(model.parameters(), lr=0.03) # Creates a GradScaler once at the beginning of training. @@ -5573,7 +5576,7 @@ class DistributedTest: # as baseline # cpu training setup - model = BatchNormNet() if affine else BatchNormNet(affine=False) + model = BN_NET if affine else BN_NET_NO_AFFINE # single gpu training setup model_gpu = copy.deepcopy(model) @@ -5623,7 +5626,6 @@ class DistributedTest: def _test_post_localSGD_optimizer_parity(self, create_averager, grad_is_view): learning_rate = 0.03 - DDP_NET = Net() net = torch.nn.parallel.DistributedDataParallel( copy.deepcopy(DDP_NET).cuda(), device_ids=[self.rank], @@ -5690,7 +5692,7 @@ class DistributedTest: learning_rate = 0.03 net_using_post_localSGD_opt = torch.nn.parallel.DistributedDataParallel( - Net().cuda(), device_ids=[self.rank] + copy.deepcopy(DDP_NET).cuda(), device_ids=[self.rank] ) averager = create_averager() @@ -5840,7 +5842,7 @@ class DistributedTest: bs_offset = int(rank * 2) global_bs = int(num_processes * 2) - model = nn.SyncBatchNorm(2, momentum=0.99) + model = ONLY_SBN_NET model_gpu = copy.deepcopy(model).cuda(rank) model_DDP = nn.parallel.DistributedDataParallel( model_gpu, device_ids=[rank] @@ -6050,7 +6052,6 @@ class DistributedTest: def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_Running_Value( self, ): - ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99) _group, _group_id, rank = self._init_global_test() model = nn.parallel.DistributedDataParallel( ONLY_SBN_NET.cuda(rank), device_ids=[rank] @@ -6118,7 +6119,7 @@ class DistributedTest: def test_DistributedDataParallel_SyncBatchNorm_half(self): _group, _group_id, rank = self._init_global_test() - model = BatchNormNet() + model = copy.deepcopy(BN_NET) model = model.half() model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = nn.parallel.DistributedDataParallel( @@ -6134,7 +6135,7 @@ class DistributedTest: def _test_ddp_logging_data(self, is_gpu): rank = dist.get_rank() - model_DDP = Net() + model_DDP = copy.deepcopy(DDP_NET) if is_gpu: model_DDP = nn.parallel.DistributedDataParallel( model_DDP.cuda(rank), device_ids=[rank] @@ -6410,7 +6411,7 @@ class DistributedTest: BACKEND == "nccl", "nccl does not support DDP on CPU models" ) def test_static_graph_api_cpu(self): - model_DDP = nn.parallel.DistributedDataParallel(Net()) + model_DDP = nn.parallel.DistributedDataParallel(DDP_NET) expected_err = "should be called before training loop starts" with self.assertRaisesRegex(RuntimeError, expected_err): local_bs = 2 @@ -6643,7 +6644,7 @@ class DistributedTest: def _test_allgather_object(self, subgroup=None): # Only set device for NCCL backend since it must use GPUs. - gather_objects = create_collectives_object_test_list() + gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy() backend = os.environ["BACKEND"] if backend == "nccl": @@ -6687,7 +6688,7 @@ class DistributedTest: def _test_gather_object(self, pg=None): # Ensure stateful objects can be gathered - gather_objects = create_collectives_object_test_list() + gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy() my_rank = dist.get_rank(pg) backend = os.environ["BACKEND"] @@ -7257,7 +7258,7 @@ class DistributedTest: return x torch.cuda.set_device(self.rank) - model_bn = BatchNormNet() + model_bn = BN_NET model_bn = nn.SyncBatchNorm.convert_sync_batchnorm( copy.deepcopy(model_bn) ).cuda(self.rank) @@ -7553,7 +7554,7 @@ class DistributedTest: loss.backward() def _test_broadcast_object_list(self, group=None): - gather_objects = create_collectives_object_test_list() + gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy() # Only set device for NCCL backend since it must use GPUs. # Case where rank != GPU device. @@ -8277,11 +8278,10 @@ class DistributedTest: @require_backend_is_available({"gloo"}) def test_scatter_object_list(self): src_rank = 0 - collectives_object_test_list = create_collectives_object_test_list() scatter_list = ( - collectives_object_test_list + COLLECTIVES_OBJECT_TEST_LIST if self.rank == src_rank - else [None for _ in collectives_object_test_list] + else [None for _ in COLLECTIVES_OBJECT_TEST_LIST] ) world_size = dist.get_world_size() scatter_list = scatter_list[:world_size] @@ -8294,8 +8294,8 @@ class DistributedTest: dist.scatter_object_list(output_obj_list, scatter_list, src=src_rank) self.assertEqual( output_obj_list[0], - collectives_object_test_list[ - self.rank % len(collectives_object_test_list) + COLLECTIVES_OBJECT_TEST_LIST[ + self.rank % len(COLLECTIVES_OBJECT_TEST_LIST) ], ) # Ensure errors are raised upon incorrect arguments. @@ -9981,7 +9981,7 @@ class DistributedTest: "Only Nccl & Gloo backend support DistributedDataParallel", ) def test_sync_bn_logged(self): - model = BatchNormNet() + model = BN_NET rank = self.rank # single gpu training setup model_gpu = model.cuda(rank)