Revert "Stop parsing command line arguments every time common_utils is imported. (#156703)"

This reverts commit 310f901a71e53688866b14bb2f2b4c8eef9979b3.

Reverted https://github.com/pytorch/pytorch/pull/156703 on behalf of https://github.com/izaitsevfb due to breaking tests internally with `assert common_utils.SEED is not None` ([comment](https://github.com/pytorch/pytorch/pull/156703#issuecomment-3152337518))
This commit is contained in:
PyTorch MergeBot
2025-08-04 20:37:39 +00:00
parent d4109a0f99
commit 356ac3103a
16 changed files with 109 additions and 255 deletions

View File

@ -104,31 +104,6 @@ except ImportError:
MI300_ARCH = ("gfx942",)
class ProfilingMode(Enum):
LEGACY = 1
SIMPLE = 2
PROFILING = 3
# Set by parse_cmd_line_args() if called
CI_FUNCTORCH_ROOT = ""
CI_PT_ROOT = ""
CI_TEST_PREFIX = ""
DISABLED_TESTS_FILE = ""
GRAPH_EXECUTOR : Optional[ProfilingMode] = None
LOG_SUFFIX = ""
PYTEST_SINGLE_TEST = ""
REPEAT_COUNT = 0
RERUN_DISABLED_TESTS = False
RUN_PARALLEL = 0
SEED : Optional[int] = None
SHOWLOCALS = False
SLOW_TESTS_FILE = ""
TEST_BAILOUTS = False
TEST_DISCOVER = False
TEST_IN_SUBPROCESS = False
TEST_SAVE_XML = ""
UNITTEST_ARGS : list[str] = []
USE_PYTEST = False
def freeze_rng_state(*args, **kwargs):
return torch.testing._utils.freeze_rng_state(*args, **kwargs)
@ -863,6 +838,11 @@ class decorateIf(_TestParametrizer):
yield (test_wrapper, test_name, {}, decorator_fn)
class ProfilingMode(Enum):
LEGACY = 1
SIMPLE = 2
PROFILING = 3
def cppProfilingFlagsToProfilingMode():
old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
@ -881,7 +861,6 @@ def cppProfilingFlagsToProfilingMode():
def enable_profiling_mode_for_profiling_tests():
old_prof_exec_state = False
old_prof_mode_state = False
assert GRAPH_EXECUTOR
if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
@ -916,7 +895,6 @@ meth_call = torch._C.ScriptMethod.__call__
def prof_callable(callable, *args, **kwargs):
if 'profile_and_replay' in kwargs:
del kwargs['profile_and_replay']
assert GRAPH_EXECUTOR
if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
with enable_profiling_mode_for_profiling_tests():
callable(*args, **kwargs)
@ -946,93 +924,72 @@ def _get_test_report_path():
test_source = override if override is not None else 'python-unittest'
return os.path.join('test-reports', test_source)
def parse_cmd_line_args():
global CI_FUNCTORCH_ROOT
global CI_PT_ROOT
global CI_TEST_PREFIX
global DISABLED_TESTS_FILE
global GRAPH_EXECUTOR
global LOG_SUFFIX
global PYTEST_SINGLE_TEST
global REPEAT_COUNT
global RERUN_DISABLED_TESTS
global RUN_PARALLEL
global SEED
global SHOWLOCALS
global SLOW_TESTS_FILE
global TEST_BAILOUTS
global TEST_DISCOVER
global TEST_IN_SUBPROCESS
global TEST_SAVE_XML
global UNITTEST_ARGS
global USE_PYTEST
is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "")
parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False)
parser.add_argument('--subprocess', action='store_true',
help='whether to run each test in a subprocess')
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--accept', action='store_true')
parser.add_argument('--jit-executor', '--jit_executor', type=str)
parser.add_argument('--repeat', type=int, default=1)
parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
parser.add_argument('--use-pytest', action='store_true')
parser.add_argument('--save-xml', nargs='?', type=str,
const=_get_test_report_path(),
default=_get_test_report_path() if IS_CI else None)
parser.add_argument('--discover-tests', action='store_true')
parser.add_argument('--log-suffix', type=str, default="")
parser.add_argument('--run-parallel', type=int, default=1)
parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE)
parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE)
parser.add_argument('--rerun-disabled-tests', action='store_true')
parser.add_argument('--pytest-single-test', type=str, nargs=1)
parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False)
is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "")
parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False)
parser.add_argument('--subprocess', action='store_true',
help='whether to run each test in a subprocess')
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--accept', action='store_true')
parser.add_argument('--jit-executor', '--jit_executor', type=str)
parser.add_argument('--repeat', type=int, default=1)
parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
parser.add_argument('--use-pytest', action='store_true')
parser.add_argument('--save-xml', nargs='?', type=str,
const=_get_test_report_path(),
default=_get_test_report_path() if IS_CI else None)
parser.add_argument('--discover-tests', action='store_true')
parser.add_argument('--log-suffix', type=str, default="")
parser.add_argument('--run-parallel', type=int, default=1)
parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE)
parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE)
parser.add_argument('--rerun-disabled-tests', action='store_true')
parser.add_argument('--pytest-single-test', type=str, nargs=1)
parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False)
# Only run when -h or --help flag is active to display both unittest and parser help messages.
def run_unittest_help(argv):
unittest.main(argv=argv)
def run_unittest_help(argv):
unittest.main(argv=argv)
if '-h' in sys.argv or '--help' in sys.argv:
help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,))
help_thread.start()
help_thread.join()
if '-h' in sys.argv or '--help' in sys.argv:
help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,))
help_thread.start()
help_thread.join()
args, remaining = parser.parse_known_args()
if args.jit_executor == 'legacy':
GRAPH_EXECUTOR = ProfilingMode.LEGACY
elif args.jit_executor == 'profiling':
GRAPH_EXECUTOR = ProfilingMode.PROFILING
elif args.jit_executor == 'simple':
GRAPH_EXECUTOR = ProfilingMode.SIMPLE
else:
# infer flags based on the default settings
GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode()
args, remaining = parser.parse_known_args()
if args.jit_executor == 'legacy':
GRAPH_EXECUTOR = ProfilingMode.LEGACY
elif args.jit_executor == 'profiling':
GRAPH_EXECUTOR = ProfilingMode.PROFILING
elif args.jit_executor == 'simple':
GRAPH_EXECUTOR = ProfilingMode.SIMPLE
else:
# infer flags based on the default settings
GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode()
RERUN_DISABLED_TESTS = args.rerun_disabled_tests
RERUN_DISABLED_TESTS = args.rerun_disabled_tests
SLOW_TESTS_FILE = args.import_slow_tests
DISABLED_TESTS_FILE = args.import_disabled_tests
LOG_SUFFIX = args.log_suffix
RUN_PARALLEL = args.run_parallel
TEST_BAILOUTS = args.test_bailouts
USE_PYTEST = args.use_pytest
PYTEST_SINGLE_TEST = args.pytest_single_test
TEST_DISCOVER = args.discover_tests
TEST_IN_SUBPROCESS = args.subprocess
TEST_SAVE_XML = args.save_xml
REPEAT_COUNT = args.repeat
SEED = args.seed
SHOWLOCALS = args.showlocals
if not getattr(expecttest, "ACCEPT", False):
expecttest.ACCEPT = args.accept
UNITTEST_ARGS = [sys.argv[0]] + remaining
torch.manual_seed(SEED)
SLOW_TESTS_FILE = args.import_slow_tests
DISABLED_TESTS_FILE = args.import_disabled_tests
LOG_SUFFIX = args.log_suffix
RUN_PARALLEL = args.run_parallel
TEST_BAILOUTS = args.test_bailouts
USE_PYTEST = args.use_pytest
PYTEST_SINGLE_TEST = args.pytest_single_test
TEST_DISCOVER = args.discover_tests
TEST_IN_SUBPROCESS = args.subprocess
TEST_SAVE_XML = args.save_xml
REPEAT_COUNT = args.repeat
SEED = args.seed
SHOWLOCALS = args.showlocals
if not getattr(expecttest, "ACCEPT", False):
expecttest.ACCEPT = args.accept
UNITTEST_ARGS = [sys.argv[0]] + remaining
torch.manual_seed(SEED)
# CI Prefix path used only on CI environment
CI_TEST_PREFIX = str(Path(os.getcwd()))
CI_PT_ROOT = str(Path(os.getcwd()).parent)
CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))
CI_TEST_PREFIX = str(Path(os.getcwd()))
CI_PT_ROOT = str(Path(os.getcwd()).parent)
CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))
def wait_for_process(p, timeout=None):
try:
@ -1181,9 +1138,7 @@ def lint_test_case_extension(suite):
return succeed
def get_report_path(argv=None, pytest=False):
if argv is None:
argv = UNITTEST_ARGS
def get_report_path(argv=UNITTEST_ARGS, pytest=False):
test_filename = sanitize_test_filename(argv[0])
test_report_path = TEST_SAVE_XML + LOG_SUFFIX
test_report_path = os.path.join(test_report_path, test_filename)
@ -1234,11 +1189,7 @@ def get_pytest_test_cases(argv: list[str]) -> list[str]:
return test_collector_plugin.tests
def run_tests(argv=None):
parse_cmd_line_args()
if argv is None:
argv = UNITTEST_ARGS
def run_tests(argv=UNITTEST_ARGS):
# import test files.
if SLOW_TESTS_FILE:
if os.path.exists(SLOW_TESTS_FILE):
@ -1804,7 +1755,6 @@ def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT exe
if not isinstance(fn, type):
@wraps(fn)
def wrapper(*args, **kwargs):
assert GRAPH_EXECUTOR
if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
raise unittest.SkipTest(msg)
else:
@ -2435,19 +2385,7 @@ def get_function_arglist(func):
return inspect.getfullargspec(func).args
def set_rng_seed(seed=None):
if seed is None:
if SEED is not None:
seed = SEED
else:
# Can't assert here: this function is called by TestCase.setUp() and some out of tree tests inherit from that class.
# So just print a warning and hardcode the seed.
seed = 1234
msg = ("set_rng_seed() was called without providing a seed and the command line "
f"arguments haven't been parsed so the seed will be set to {seed}. "
"To remove this warning make sure your test is run via run_tests() or "
"parse_cmd_line_args() is called before set_rng_seed() is called.")
warnings.warn(msg)
def set_rng_seed(seed):
torch.manual_seed(seed)
random.seed(seed)
if TEST_NUMPY:
@ -3470,7 +3408,7 @@ class TestCase(expecttest.TestCase):
def setUp(self):
check_if_enable(self)
set_rng_seed()
set_rng_seed(SEED)
# Save global check sparse tensor invariants state that can be
# restored from tearDown: