Revert "Stop parsing command line arguments every time common_utils is imported. (#156703)"

This reverts commit 310f901a71e53688866b14bb2f2b4c8eef9979b3.

Reverted https://github.com/pytorch/pytorch/pull/156703 on behalf of https://github.com/izaitsevfb due to breaking tests internally with `assert common_utils.SEED is not None` ([comment](https://github.com/pytorch/pytorch/pull/156703#issuecomment-3152337518))
This commit is contained in:
PyTorch MergeBot
2025-08-04 20:37:39 +00:00
parent d4109a0f99
commit 356ac3103a
16 changed files with 109 additions and 255 deletions

View File

@ -4280,11 +4280,10 @@ class NCCLTraceTestBase(MultiProcessTestCase):
test_name: str, test_name: str,
file_name: str, file_name: str,
parent_pipe, parent_pipe,
seed: int,
**kwargs, **kwargs,
) -> None: ) -> None:
cls.parent = parent_conn cls.parent = parent_conn
super()._run(rank, test_name, file_name, parent_pipe, seed) super()._run(rank, test_name, file_name, parent_pipe)
@property @property
def local_device(self): def local_device(self):

View File

@ -27,9 +27,6 @@ from torch.testing._internal.jit_utils import (
) )
assert GRAPH_EXECUTOR is not None
@unittest.skipIf( @unittest.skipIf(
GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients" GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients"
) )

View File

@ -35,11 +35,6 @@ class TestCppApiParity(common.TestCase):
functional_test_params_map = {} functional_test_params_map = {}
if __name__ == "__main__":
# The value of the SEED depends on command line arguments so make sure they're parsed
# before instantiating tests because some modules as part of get_new_module_tests() will call torch.randn
common.parse_cmd_line_args()
expected_test_params_dicts = [] expected_test_params_dicts = []
for test_params_dicts, test_instance_class in [ for test_params_dicts, test_instance_class in [

View File

@ -1008,13 +1008,6 @@ def filter_supported_tests(t):
return True return True
if __name__ == "__main__":
from torch.testing._internal.common_utils import parse_cmd_line_args
# The value of the SEED depends on command line arguments so make sure they're parsed
# before instantiating tests because some modules as part of get_new_module_tests() will call torch.randn
parse_cmd_line_args()
# TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests # TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests
# These currently use the legacy nn tests # These currently use the legacy nn tests
supported_tests = [ supported_tests = [

View File

@ -3,13 +3,6 @@
import torch import torch
if __name__ == '__main__':
from torch.testing._internal.common_utils import parse_cmd_line_args
# The value of GRAPH_EXECUTOR and SEED depend on command line arguments so make sure they're parsed
# before instantiating tests.
parse_cmd_line_args()
# This is how we include tests located in test/jit/... # This is how we include tests located in test/jit/...
# They are included here so that they are invoked when you call `test_jit.py`, # They are included here so that they are invoked when you call `test_jit.py`,
# do not run these test files directly. # do not run these test files directly.
@ -104,7 +97,7 @@ import torch.nn.functional as F
from torch.testing._internal import jit_utils from torch.testing._internal import jit_utils
from torch.testing._internal.common_jit import check_against_reference from torch.testing._internal.common_jit import check_against_reference
from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, \ from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, \
GRAPH_EXECUTOR, suppress_warnings, IS_SANDCASTLE, ProfilingMode, \ suppress_warnings, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, \
TestCase, freeze_rng_state, slowTest, TemporaryFileName, \ TestCase, freeze_rng_state, slowTest, TemporaryFileName, \
enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \ enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \
skipIfCrossRef, skipIfTorchDynamo skipIfCrossRef, skipIfTorchDynamo
@ -165,7 +158,6 @@ def doAutodiffCheck(testname):
if "test_t_" in testname or testname == "test_t": if "test_t_" in testname or testname == "test_t":
return False return False
assert GRAPH_EXECUTOR
if GRAPH_EXECUTOR == ProfilingMode.SIMPLE: if GRAPH_EXECUTOR == ProfilingMode.SIMPLE:
return False return False
@ -209,7 +201,6 @@ def doAutodiffCheck(testname):
return testname not in test_exceptions return testname not in test_exceptions
assert GRAPH_EXECUTOR
# TODO: enable TE in PE when all tests are fixed # TODO: enable TE in PE when all tests are fixed
torch._C._jit_set_texpr_fuser_enabled(GRAPH_EXECUTOR == ProfilingMode.PROFILING) torch._C._jit_set_texpr_fuser_enabled(GRAPH_EXECUTOR == ProfilingMode.PROFILING)
torch._C._jit_set_profiling_executor(GRAPH_EXECUTOR != ProfilingMode.LEGACY) torch._C._jit_set_profiling_executor(GRAPH_EXECUTOR != ProfilingMode.LEGACY)

View File

@ -5,17 +5,12 @@ from torch.cuda.amp import autocast
from typing import Optional from typing import Optional
import unittest import unittest
from test_jit import JitTestCase
from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests, skipIfTorchDynamo from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
from torch.testing import FileCheck from torch.testing import FileCheck
from jit.test_models import MnistNet from jit.test_models import MnistNet
if __name__ == '__main__':
# The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
# before instantiating tests.
parse_cmd_line_args()
from test_jit import JitTestCase
TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported() TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()
@skipIfTorchDynamo("Not a TorchDynamo suitable test") @skipIfTorchDynamo("Not a TorchDynamo suitable test")

View File

@ -9,13 +9,6 @@ import torch.nn.functional as F
from torch.testing import FileCheck from torch.testing import FileCheck
from unittest import skipIf from unittest import skipIf
if __name__ == "__main__":
from torch.testing._internal.common_utils import parse_cmd_line_args
# The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
# before instantiating tests.
parse_cmd_line_args()
from torch.testing._internal.common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \ from torch.testing._internal.common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \
enable_profiling_mode_for_profiling_tests, IS_WINDOWS, TemporaryDirectoryName, shell enable_profiling_mode_for_profiling_tests, IS_WINDOWS, TemporaryDirectoryName, shell
from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, _inline_everything, \ from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, _inline_everything, \

View File

@ -2,14 +2,6 @@
import sys import sys
sys.argv.append("--jit-executor=legacy") sys.argv.append("--jit-executor=legacy")
if __name__ == "__main__":
from torch.testing._internal.common_utils import parse_cmd_line_args
# The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
# before instantiating tests.
parse_cmd_line_args()
from test_jit_fuser import * # noqa: F403 from test_jit_fuser import * # noqa: F403
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -22,13 +22,6 @@ from torch.testing import FileCheck
torch._C._jit_set_profiling_executor(True) torch._C._jit_set_profiling_executor(True)
torch._C._get_graph_executor_optimize(True) torch._C._get_graph_executor_optimize(True)
if __name__ == "__main__":
from torch.testing._internal.common_utils import parse_cmd_line_args
# The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
# before instantiating tests.
parse_cmd_line_args()
from itertools import combinations, permutations, product from itertools import combinations, permutations, product
from textwrap import dedent from textwrap import dedent

View File

@ -2,14 +2,7 @@
import sys import sys
sys.argv.append("--jit-executor=legacy") sys.argv.append("--jit-executor=legacy")
from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests from test_jit import * # noqa: F403
if __name__ == '__main__':
# The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
# before instantiating tests.
parse_cmd_line_args()
from test_jit import * # noqa: F403, F401
if __name__ == '__main__': if __name__ == '__main__':
run_tests() run_tests()

View File

@ -7643,13 +7643,6 @@ def add_test(test, decorator=None):
else: else:
add(cuda_test_name, with_tf32_off) add(cuda_test_name, with_tf32_off)
if __name__ == '__main__':
from torch.testing._internal.common_utils import parse_cmd_line_args
# The value of the SEED depends on command line arguments so make sure they're parsed
# before instantiating tests because some modules as part of get_new_module_tests() will call torch.randn
parse_cmd_line_args()
for test_params in module_tests + get_new_module_tests(): for test_params in module_tests + get_new_module_tests():
# TODO: CUDA is not implemented yet # TODO: CUDA is not implemented yet
if 'constructor' not in test_params: if 'constructor' not in test_params:

View File

@ -32,7 +32,6 @@ import torch.nn as nn
from torch._C._autograd import DeviceType from torch._C._autograd import DeviceType
from torch._C._distributed_c10d import _SymmetricMemory from torch._C._distributed_c10d import _SymmetricMemory
from torch._logging._internal import trace_log from torch._logging._internal import trace_log
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
FILE_SCHEMA, FILE_SCHEMA,
find_free_port, find_free_port,
@ -672,7 +671,6 @@ class MultiProcessTestCase(TestCase):
if methodName != "runTest": if methodName != "runTest":
method_name = methodName method_name = methodName
super().__init__(method_name) super().__init__(method_name)
self.seed = None
try: try:
fn = getattr(self, method_name) fn = getattr(self, method_name)
setattr(self, method_name, self.join_or_run(fn)) setattr(self, method_name, self.join_or_run(fn))
@ -717,20 +715,13 @@ class MultiProcessTestCase(TestCase):
def _start_processes(self, proc) -> None: def _start_processes(self, proc) -> None:
self.processes = [] self.processes = []
assert common_utils.SEED is not None
for rank in range(int(self.world_size)): for rank in range(int(self.world_size)):
parent_conn, child_conn = torch.multiprocessing.Pipe() parent_conn, child_conn = torch.multiprocessing.Pipe()
process = proc( process = proc(
target=self.__class__._run, target=self.__class__._run,
name="process " + str(rank), name="process " + str(rank),
args=( args=(rank, self._current_test_name(), self.file_name, child_conn),
rank,
self._current_test_name(),
self.file_name,
child_conn,
),
kwargs={ kwargs={
"seed": common_utils.SEED,
"fake_pg": getattr(self, "fake_pg", False), "fake_pg": getattr(self, "fake_pg", False),
}, },
) )
@ -784,12 +775,11 @@ class MultiProcessTestCase(TestCase):
@classmethod @classmethod
def _run( def _run(
cls, rank: int, test_name: str, file_name: str, parent_pipe, seed: int, **kwargs cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs
) -> None: ) -> None:
self = cls(test_name) self = cls(test_name)
self.rank = rank self.rank = rank
self.file_name = file_name self.file_name = file_name
self.seed = seed
self.run_test(test_name, parent_pipe) self.run_test(test_name, parent_pipe)
def run_test(self, test_name: str, parent_pipe) -> None: def run_test(self, test_name: str, parent_pipe) -> None:
@ -808,9 +798,6 @@ class MultiProcessTestCase(TestCase):
# Show full C++ stacktraces when a Python error originating from C++ is raised. # Show full C++ stacktraces when a Python error originating from C++ is raised.
os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1" os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"
if self.seed is not None:
common_utils.set_rng_seed(self.seed)
# self.id() == e.g. '__main__.TestDistributed.test_get_rank' # self.id() == e.g. '__main__.TestDistributed.test_get_rank'
# We're retrieving a corresponding test and executing it. # We're retrieving a corresponding test and executing it.
try: try:
@ -1548,7 +1535,7 @@ class DynamoDistributedMultiProcTestCase(DistributedTestBase):
@classmethod @classmethod
def _run( def _run(
cls, rank: int, test_name: str, file_name: str, parent_pipe, seed: int, **kwargs cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs
) -> None: ) -> None:
trace_log.addHandler(logging.NullHandler()) trace_log.addHandler(logging.NullHandler())
@ -1556,7 +1543,6 @@ class DynamoDistributedMultiProcTestCase(DistributedTestBase):
self = cls(test_name) self = cls(test_name)
self.rank = rank self.rank = rank
self.file_name = file_name self.file_name = file_name
self.seed = seed
self.run_test(test_name, parent_pipe) self.run_test(test_name, parent_pipe)

View File

@ -57,7 +57,6 @@ from torch.testing._internal.common_distributed import (
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
FILE_SCHEMA, FILE_SCHEMA,
get_cycles_per_ms, get_cycles_per_ms,
set_rng_seed,
TEST_CUDA, TEST_CUDA,
TEST_HPU, TEST_HPU,
TEST_XPU, TEST_XPU,
@ -1181,7 +1180,7 @@ class FSDPTest(MultiProcessTestCase):
return run_subtests(self, *args, **kwargs) return run_subtests(self, *args, **kwargs)
@classmethod @classmethod
def _run(cls, rank, test_name, file_name, pipe, seed, **kwargs): # type: ignore[override] def _run(cls, rank, test_name, file_name, pipe, **kwargs): # type: ignore[override]
self = cls(test_name) self = cls(test_name)
self.rank = rank self.rank = rank
self.file_name = file_name self.file_name = file_name
@ -1227,7 +1226,6 @@ class FSDPTest(MultiProcessTestCase):
dist.barrier(device_ids=device_ids) dist.barrier(device_ids=device_ids)
torch._dynamo.reset() torch._dynamo.reset()
set_rng_seed(seed)
self.run_test(test_name, pipe) self.run_test(test_name, pipe)
torch._dynamo.reset() torch._dynamo.reset()

View File

@ -15,7 +15,6 @@ import torch.cuda
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn import _reduction as _Reduction from torch.nn import _reduction as _Reduction
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \ from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
gradcheck, gradgradcheck, set_default_dtype, skipIfTorchDynamo, TEST_WITH_ROCM gradcheck, gradgradcheck, set_default_dtype, skipIfTorchDynamo, TEST_WITH_ROCM
from torch.testing._internal.common_cuda import TEST_CUDA, SM90OrLater from torch.testing._internal.common_cuda import TEST_CUDA, SM90OrLater
@ -1079,7 +1078,6 @@ def single_batch_reference_fn(input, parameters, module):
def get_new_module_tests(): def get_new_module_tests():
assert common_utils.SEED is not None, "Make sure the seed is set before calling get_new_module_tests()"
new_module_tests = [ new_module_tests = [
poissonnllloss_no_reduce_test(), poissonnllloss_no_reduce_test(),
bceloss_no_reduce_test(), bceloss_no_reduce_test(),

View File

@ -104,31 +104,6 @@ except ImportError:
MI300_ARCH = ("gfx942",) MI300_ARCH = ("gfx942",)
class ProfilingMode(Enum):
LEGACY = 1
SIMPLE = 2
PROFILING = 3
# Set by parse_cmd_line_args() if called
CI_FUNCTORCH_ROOT = ""
CI_PT_ROOT = ""
CI_TEST_PREFIX = ""
DISABLED_TESTS_FILE = ""
GRAPH_EXECUTOR : Optional[ProfilingMode] = None
LOG_SUFFIX = ""
PYTEST_SINGLE_TEST = ""
REPEAT_COUNT = 0
RERUN_DISABLED_TESTS = False
RUN_PARALLEL = 0
SEED : Optional[int] = None
SHOWLOCALS = False
SLOW_TESTS_FILE = ""
TEST_BAILOUTS = False
TEST_DISCOVER = False
TEST_IN_SUBPROCESS = False
TEST_SAVE_XML = ""
UNITTEST_ARGS : list[str] = []
USE_PYTEST = False
def freeze_rng_state(*args, **kwargs): def freeze_rng_state(*args, **kwargs):
return torch.testing._utils.freeze_rng_state(*args, **kwargs) return torch.testing._utils.freeze_rng_state(*args, **kwargs)
@ -863,6 +838,11 @@ class decorateIf(_TestParametrizer):
yield (test_wrapper, test_name, {}, decorator_fn) yield (test_wrapper, test_name, {}, decorator_fn)
class ProfilingMode(Enum):
LEGACY = 1
SIMPLE = 2
PROFILING = 3
def cppProfilingFlagsToProfilingMode(): def cppProfilingFlagsToProfilingMode():
old_prof_exec_state = torch._C._jit_set_profiling_executor(True) old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
old_prof_mode_state = torch._C._get_graph_executor_optimize(True) old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
@ -881,7 +861,6 @@ def cppProfilingFlagsToProfilingMode():
def enable_profiling_mode_for_profiling_tests(): def enable_profiling_mode_for_profiling_tests():
old_prof_exec_state = False old_prof_exec_state = False
old_prof_mode_state = False old_prof_mode_state = False
assert GRAPH_EXECUTOR
if GRAPH_EXECUTOR == ProfilingMode.PROFILING: if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
old_prof_exec_state = torch._C._jit_set_profiling_executor(True) old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
old_prof_mode_state = torch._C._get_graph_executor_optimize(True) old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
@ -916,7 +895,6 @@ meth_call = torch._C.ScriptMethod.__call__
def prof_callable(callable, *args, **kwargs): def prof_callable(callable, *args, **kwargs):
if 'profile_and_replay' in kwargs: if 'profile_and_replay' in kwargs:
del kwargs['profile_and_replay'] del kwargs['profile_and_replay']
assert GRAPH_EXECUTOR
if GRAPH_EXECUTOR == ProfilingMode.PROFILING: if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
with enable_profiling_mode_for_profiling_tests(): with enable_profiling_mode_for_profiling_tests():
callable(*args, **kwargs) callable(*args, **kwargs)
@ -946,93 +924,72 @@ def _get_test_report_path():
test_source = override if override is not None else 'python-unittest' test_source = override if override is not None else 'python-unittest'
return os.path.join('test-reports', test_source) return os.path.join('test-reports', test_source)
def parse_cmd_line_args(): is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "")
global CI_FUNCTORCH_ROOT parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False)
global CI_PT_ROOT parser.add_argument('--subprocess', action='store_true',
global CI_TEST_PREFIX help='whether to run each test in a subprocess')
global DISABLED_TESTS_FILE parser.add_argument('--seed', type=int, default=1234)
global GRAPH_EXECUTOR parser.add_argument('--accept', action='store_true')
global LOG_SUFFIX parser.add_argument('--jit-executor', '--jit_executor', type=str)
global PYTEST_SINGLE_TEST parser.add_argument('--repeat', type=int, default=1)
global REPEAT_COUNT parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
global RERUN_DISABLED_TESTS parser.add_argument('--use-pytest', action='store_true')
global RUN_PARALLEL parser.add_argument('--save-xml', nargs='?', type=str,
global SEED const=_get_test_report_path(),
global SHOWLOCALS default=_get_test_report_path() if IS_CI else None)
global SLOW_TESTS_FILE parser.add_argument('--discover-tests', action='store_true')
global TEST_BAILOUTS parser.add_argument('--log-suffix', type=str, default="")
global TEST_DISCOVER parser.add_argument('--run-parallel', type=int, default=1)
global TEST_IN_SUBPROCESS parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE)
global TEST_SAVE_XML parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE)
global UNITTEST_ARGS parser.add_argument('--rerun-disabled-tests', action='store_true')
global USE_PYTEST parser.add_argument('--pytest-single-test', type=str, nargs=1)
parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False)
is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "")
parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False)
parser.add_argument('--subprocess', action='store_true',
help='whether to run each test in a subprocess')
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--accept', action='store_true')
parser.add_argument('--jit-executor', '--jit_executor', type=str)
parser.add_argument('--repeat', type=int, default=1)
parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
parser.add_argument('--use-pytest', action='store_true')
parser.add_argument('--save-xml', nargs='?', type=str,
const=_get_test_report_path(),
default=_get_test_report_path() if IS_CI else None)
parser.add_argument('--discover-tests', action='store_true')
parser.add_argument('--log-suffix', type=str, default="")
parser.add_argument('--run-parallel', type=int, default=1)
parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE)
parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE)
parser.add_argument('--rerun-disabled-tests', action='store_true')
parser.add_argument('--pytest-single-test', type=str, nargs=1)
parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False)
# Only run when -h or --help flag is active to display both unittest and parser help messages. # Only run when -h or --help flag is active to display both unittest and parser help messages.
def run_unittest_help(argv): def run_unittest_help(argv):
unittest.main(argv=argv) unittest.main(argv=argv)
if '-h' in sys.argv or '--help' in sys.argv: if '-h' in sys.argv or '--help' in sys.argv:
help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,)) help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,))
help_thread.start() help_thread.start()
help_thread.join() help_thread.join()
args, remaining = parser.parse_known_args() args, remaining = parser.parse_known_args()
if args.jit_executor == 'legacy': if args.jit_executor == 'legacy':
GRAPH_EXECUTOR = ProfilingMode.LEGACY GRAPH_EXECUTOR = ProfilingMode.LEGACY
elif args.jit_executor == 'profiling': elif args.jit_executor == 'profiling':
GRAPH_EXECUTOR = ProfilingMode.PROFILING GRAPH_EXECUTOR = ProfilingMode.PROFILING
elif args.jit_executor == 'simple': elif args.jit_executor == 'simple':
GRAPH_EXECUTOR = ProfilingMode.SIMPLE GRAPH_EXECUTOR = ProfilingMode.SIMPLE
else: else:
# infer flags based on the default settings # infer flags based on the default settings
GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode() GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode()
RERUN_DISABLED_TESTS = args.rerun_disabled_tests RERUN_DISABLED_TESTS = args.rerun_disabled_tests
SLOW_TESTS_FILE = args.import_slow_tests SLOW_TESTS_FILE = args.import_slow_tests
DISABLED_TESTS_FILE = args.import_disabled_tests DISABLED_TESTS_FILE = args.import_disabled_tests
LOG_SUFFIX = args.log_suffix LOG_SUFFIX = args.log_suffix
RUN_PARALLEL = args.run_parallel RUN_PARALLEL = args.run_parallel
TEST_BAILOUTS = args.test_bailouts TEST_BAILOUTS = args.test_bailouts
USE_PYTEST = args.use_pytest USE_PYTEST = args.use_pytest
PYTEST_SINGLE_TEST = args.pytest_single_test PYTEST_SINGLE_TEST = args.pytest_single_test
TEST_DISCOVER = args.discover_tests TEST_DISCOVER = args.discover_tests
TEST_IN_SUBPROCESS = args.subprocess TEST_IN_SUBPROCESS = args.subprocess
TEST_SAVE_XML = args.save_xml TEST_SAVE_XML = args.save_xml
REPEAT_COUNT = args.repeat REPEAT_COUNT = args.repeat
SEED = args.seed SEED = args.seed
SHOWLOCALS = args.showlocals SHOWLOCALS = args.showlocals
if not getattr(expecttest, "ACCEPT", False): if not getattr(expecttest, "ACCEPT", False):
expecttest.ACCEPT = args.accept expecttest.ACCEPT = args.accept
UNITTEST_ARGS = [sys.argv[0]] + remaining UNITTEST_ARGS = [sys.argv[0]] + remaining
torch.manual_seed(SEED) torch.manual_seed(SEED)
# CI Prefix path used only on CI environment # CI Prefix path used only on CI environment
CI_TEST_PREFIX = str(Path(os.getcwd())) CI_TEST_PREFIX = str(Path(os.getcwd()))
CI_PT_ROOT = str(Path(os.getcwd()).parent) CI_PT_ROOT = str(Path(os.getcwd()).parent)
CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch")) CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))
def wait_for_process(p, timeout=None): def wait_for_process(p, timeout=None):
try: try:
@ -1181,9 +1138,7 @@ def lint_test_case_extension(suite):
return succeed return succeed
def get_report_path(argv=None, pytest=False): def get_report_path(argv=UNITTEST_ARGS, pytest=False):
if argv is None:
argv = UNITTEST_ARGS
test_filename = sanitize_test_filename(argv[0]) test_filename = sanitize_test_filename(argv[0])
test_report_path = TEST_SAVE_XML + LOG_SUFFIX test_report_path = TEST_SAVE_XML + LOG_SUFFIX
test_report_path = os.path.join(test_report_path, test_filename) test_report_path = os.path.join(test_report_path, test_filename)
@ -1234,11 +1189,7 @@ def get_pytest_test_cases(argv: list[str]) -> list[str]:
return test_collector_plugin.tests return test_collector_plugin.tests
def run_tests(argv=None): def run_tests(argv=UNITTEST_ARGS):
parse_cmd_line_args()
if argv is None:
argv = UNITTEST_ARGS
# import test files. # import test files.
if SLOW_TESTS_FILE: if SLOW_TESTS_FILE:
if os.path.exists(SLOW_TESTS_FILE): if os.path.exists(SLOW_TESTS_FILE):
@ -1804,7 +1755,6 @@ def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT exe
if not isinstance(fn, type): if not isinstance(fn, type):
@wraps(fn) @wraps(fn)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
assert GRAPH_EXECUTOR
if GRAPH_EXECUTOR == ProfilingMode.LEGACY: if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
raise unittest.SkipTest(msg) raise unittest.SkipTest(msg)
else: else:
@ -2435,19 +2385,7 @@ def get_function_arglist(func):
return inspect.getfullargspec(func).args return inspect.getfullargspec(func).args
def set_rng_seed(seed=None): def set_rng_seed(seed):
if seed is None:
if SEED is not None:
seed = SEED
else:
# Can't assert here: this function is called by TestCase.setUp() and some out of tree tests inherit from that class.
# So just print a warning and hardcode the seed.
seed = 1234
msg = ("set_rng_seed() was called without providing a seed and the command line "
f"arguments haven't been parsed so the seed will be set to {seed}. "
"To remove this warning make sure your test is run via run_tests() or "
"parse_cmd_line_args() is called before set_rng_seed() is called.")
warnings.warn(msg)
torch.manual_seed(seed) torch.manual_seed(seed)
random.seed(seed) random.seed(seed)
if TEST_NUMPY: if TEST_NUMPY:
@ -3470,7 +3408,7 @@ class TestCase(expecttest.TestCase):
def setUp(self): def setUp(self):
check_if_enable(self) check_if_enable(self)
set_rng_seed() set_rng_seed(SEED)
# Save global check sparse tensor invariants state that can be # Save global check sparse tensor invariants state that can be
# restored from tearDown: # restored from tearDown:

View File

@ -137,18 +137,16 @@ class Foo:
f = Foo(10) f = Foo(10)
f.bar = 1 f.bar = 1
foo_cpu_tensor = Foo(torch.randn(3, 3))
# Defer instantiation until the seed is set so that randn() returns the same
# values in all processes.
def create_collectives_object_test_list():
return [
{"key1": 3, "key2": 4, "key3": {"nested": True}},
f,
Foo(torch.randn(3, 3)),
"foo",
[1, 2, True, "string", [4, 5, "nested"]],
]
COLLECTIVES_OBJECT_TEST_LIST = [
{"key1": 3, "key2": 4, "key3": {"nested": True}},
f,
foo_cpu_tensor,
"foo",
[1, 2, True, "string", [4, 5, "nested"]],
]
# Allowlist of distributed backends where profiling collectives is supported. # Allowlist of distributed backends where profiling collectives is supported.
PROFILING_SUPPORTED_BACKENDS = [ PROFILING_SUPPORTED_BACKENDS = [
@ -398,6 +396,12 @@ class ControlFlowToyModel(nn.Module):
return F.relu(self.lin1(x)) return F.relu(self.lin1(x))
DDP_NET = Net()
BN_NET = BatchNormNet()
BN_NET_NO_AFFINE = BatchNormNet(affine=False)
ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99)
def get_timeout(test_id): def get_timeout(test_id):
test_name = test_id.split(".")[-1] test_name = test_id.split(".")[-1]
if test_name in CUSTOMIZED_TIMEOUT: if test_name in CUSTOMIZED_TIMEOUT:
@ -590,13 +594,12 @@ class TestDistBackend(MultiProcessTestCase):
return False return False
@classmethod @classmethod
def _run(cls, rank, test_name, file_name, pipe, seed, **kwargs): def _run(cls, rank, test_name, file_name, pipe, **kwargs):
if BACKEND == "nccl" and not torch.cuda.is_available(): if BACKEND == "nccl" and not torch.cuda.is_available():
sys.exit(TEST_SKIPS["no_cuda"].exit_code) sys.exit(TEST_SKIPS["no_cuda"].exit_code)
self = cls(test_name) self = cls(test_name)
self.rank = rank self.rank = rank
self.file_name = file_name self.file_name = file_name
self.seed = seed
if torch.cuda.is_available() and torch.cuda.device_count() < int( if torch.cuda.is_available() and torch.cuda.device_count() < int(
self.world_size self.world_size
@ -4284,7 +4287,7 @@ class DistributedTest:
# as baseline # as baseline
# cpu training setup # cpu training setup
model = Net() model = DDP_NET
# single gpu training setup # single gpu training setup
model_gpu = copy.deepcopy(model) model_gpu = copy.deepcopy(model)
@ -4339,7 +4342,7 @@ class DistributedTest:
_group, _group_id, rank = self._init_global_test() _group, _group_id, rank = self._init_global_test()
# cpu training setup # cpu training setup
model_base = Net() model_base = DDP_NET
# DDP-CPU training setup # DDP-CPU training setup
model_DDP = copy.deepcopy(model_base) model_DDP = copy.deepcopy(model_base)
@ -5488,7 +5491,7 @@ class DistributedTest:
def _test_DistributedDataParallel_with_amp(self, grad_is_view=False): def _test_DistributedDataParallel_with_amp(self, grad_is_view=False):
torch.manual_seed(31415) torch.manual_seed(31415)
# Creates model and optimizer in default precision # Creates model and optimizer in default precision
model = Net().cuda() model = copy.deepcopy(DDP_NET).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.03) optimizer = torch.optim.SGD(model.parameters(), lr=0.03)
# Creates a GradScaler once at the beginning of training. # Creates a GradScaler once at the beginning of training.
@ -5573,7 +5576,7 @@ class DistributedTest:
# as baseline # as baseline
# cpu training setup # cpu training setup
model = BatchNormNet() if affine else BatchNormNet(affine=False) model = BN_NET if affine else BN_NET_NO_AFFINE
# single gpu training setup # single gpu training setup
model_gpu = copy.deepcopy(model) model_gpu = copy.deepcopy(model)
@ -5623,7 +5626,6 @@ class DistributedTest:
def _test_post_localSGD_optimizer_parity(self, create_averager, grad_is_view): def _test_post_localSGD_optimizer_parity(self, create_averager, grad_is_view):
learning_rate = 0.03 learning_rate = 0.03
DDP_NET = Net()
net = torch.nn.parallel.DistributedDataParallel( net = torch.nn.parallel.DistributedDataParallel(
copy.deepcopy(DDP_NET).cuda(), copy.deepcopy(DDP_NET).cuda(),
device_ids=[self.rank], device_ids=[self.rank],
@ -5690,7 +5692,7 @@ class DistributedTest:
learning_rate = 0.03 learning_rate = 0.03
net_using_post_localSGD_opt = torch.nn.parallel.DistributedDataParallel( net_using_post_localSGD_opt = torch.nn.parallel.DistributedDataParallel(
Net().cuda(), device_ids=[self.rank] copy.deepcopy(DDP_NET).cuda(), device_ids=[self.rank]
) )
averager = create_averager() averager = create_averager()
@ -5840,7 +5842,7 @@ class DistributedTest:
bs_offset = int(rank * 2) bs_offset = int(rank * 2)
global_bs = int(num_processes * 2) global_bs = int(num_processes * 2)
model = nn.SyncBatchNorm(2, momentum=0.99) model = ONLY_SBN_NET
model_gpu = copy.deepcopy(model).cuda(rank) model_gpu = copy.deepcopy(model).cuda(rank)
model_DDP = nn.parallel.DistributedDataParallel( model_DDP = nn.parallel.DistributedDataParallel(
model_gpu, device_ids=[rank] model_gpu, device_ids=[rank]
@ -6050,7 +6052,6 @@ class DistributedTest:
def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_Running_Value( def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_Running_Value(
self, self,
): ):
ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99)
_group, _group_id, rank = self._init_global_test() _group, _group_id, rank = self._init_global_test()
model = nn.parallel.DistributedDataParallel( model = nn.parallel.DistributedDataParallel(
ONLY_SBN_NET.cuda(rank), device_ids=[rank] ONLY_SBN_NET.cuda(rank), device_ids=[rank]
@ -6118,7 +6119,7 @@ class DistributedTest:
def test_DistributedDataParallel_SyncBatchNorm_half(self): def test_DistributedDataParallel_SyncBatchNorm_half(self):
_group, _group_id, rank = self._init_global_test() _group, _group_id, rank = self._init_global_test()
model = BatchNormNet() model = copy.deepcopy(BN_NET)
model = model.half() model = model.half()
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = nn.parallel.DistributedDataParallel( model = nn.parallel.DistributedDataParallel(
@ -6134,7 +6135,7 @@ class DistributedTest:
def _test_ddp_logging_data(self, is_gpu): def _test_ddp_logging_data(self, is_gpu):
rank = dist.get_rank() rank = dist.get_rank()
model_DDP = Net() model_DDP = copy.deepcopy(DDP_NET)
if is_gpu: if is_gpu:
model_DDP = nn.parallel.DistributedDataParallel( model_DDP = nn.parallel.DistributedDataParallel(
model_DDP.cuda(rank), device_ids=[rank] model_DDP.cuda(rank), device_ids=[rank]
@ -6410,7 +6411,7 @@ class DistributedTest:
BACKEND == "nccl", "nccl does not support DDP on CPU models" BACKEND == "nccl", "nccl does not support DDP on CPU models"
) )
def test_static_graph_api_cpu(self): def test_static_graph_api_cpu(self):
model_DDP = nn.parallel.DistributedDataParallel(Net()) model_DDP = nn.parallel.DistributedDataParallel(DDP_NET)
expected_err = "should be called before training loop starts" expected_err = "should be called before training loop starts"
with self.assertRaisesRegex(RuntimeError, expected_err): with self.assertRaisesRegex(RuntimeError, expected_err):
local_bs = 2 local_bs = 2
@ -6643,7 +6644,7 @@ class DistributedTest:
def _test_allgather_object(self, subgroup=None): def _test_allgather_object(self, subgroup=None):
# Only set device for NCCL backend since it must use GPUs. # Only set device for NCCL backend since it must use GPUs.
gather_objects = create_collectives_object_test_list() gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
backend = os.environ["BACKEND"] backend = os.environ["BACKEND"]
if backend == "nccl": if backend == "nccl":
@ -6687,7 +6688,7 @@ class DistributedTest:
def _test_gather_object(self, pg=None): def _test_gather_object(self, pg=None):
# Ensure stateful objects can be gathered # Ensure stateful objects can be gathered
gather_objects = create_collectives_object_test_list() gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
my_rank = dist.get_rank(pg) my_rank = dist.get_rank(pg)
backend = os.environ["BACKEND"] backend = os.environ["BACKEND"]
@ -7257,7 +7258,7 @@ class DistributedTest:
return x return x
torch.cuda.set_device(self.rank) torch.cuda.set_device(self.rank)
model_bn = BatchNormNet() model_bn = BN_NET
model_bn = nn.SyncBatchNorm.convert_sync_batchnorm( model_bn = nn.SyncBatchNorm.convert_sync_batchnorm(
copy.deepcopy(model_bn) copy.deepcopy(model_bn)
).cuda(self.rank) ).cuda(self.rank)
@ -7553,7 +7554,7 @@ class DistributedTest:
loss.backward() loss.backward()
def _test_broadcast_object_list(self, group=None): def _test_broadcast_object_list(self, group=None):
gather_objects = create_collectives_object_test_list() gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
# Only set device for NCCL backend since it must use GPUs. # Only set device for NCCL backend since it must use GPUs.
# Case where rank != GPU device. # Case where rank != GPU device.
@ -8277,11 +8278,10 @@ class DistributedTest:
@require_backend_is_available({"gloo"}) @require_backend_is_available({"gloo"})
def test_scatter_object_list(self): def test_scatter_object_list(self):
src_rank = 0 src_rank = 0
collectives_object_test_list = create_collectives_object_test_list()
scatter_list = ( scatter_list = (
collectives_object_test_list COLLECTIVES_OBJECT_TEST_LIST
if self.rank == src_rank if self.rank == src_rank
else [None for _ in collectives_object_test_list] else [None for _ in COLLECTIVES_OBJECT_TEST_LIST]
) )
world_size = dist.get_world_size() world_size = dist.get_world_size()
scatter_list = scatter_list[:world_size] scatter_list = scatter_list[:world_size]
@ -8294,8 +8294,8 @@ class DistributedTest:
dist.scatter_object_list(output_obj_list, scatter_list, src=src_rank) dist.scatter_object_list(output_obj_list, scatter_list, src=src_rank)
self.assertEqual( self.assertEqual(
output_obj_list[0], output_obj_list[0],
collectives_object_test_list[ COLLECTIVES_OBJECT_TEST_LIST[
self.rank % len(collectives_object_test_list) self.rank % len(COLLECTIVES_OBJECT_TEST_LIST)
], ],
) )
# Ensure errors are raised upon incorrect arguments. # Ensure errors are raised upon incorrect arguments.
@ -9981,7 +9981,7 @@ class DistributedTest:
"Only Nccl & Gloo backend support DistributedDataParallel", "Only Nccl & Gloo backend support DistributedDataParallel",
) )
def test_sync_bn_logged(self): def test_sync_bn_logged(self):
model = BatchNormNet() model = BN_NET
rank = self.rank rank = self.rank
# single gpu training setup # single gpu training setup
model_gpu = model.cuda(rank) model_gpu = model.cuda(rank)