mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Stop parsing command line arguments every time common_utils is imported. (#156703)"
This reverts commit 310f901a71e53688866b14bb2f2b4c8eef9979b3. Reverted https://github.com/pytorch/pytorch/pull/156703 on behalf of https://github.com/izaitsevfb due to breaking tests internally with `assert common_utils.SEED is not None` ([comment](https://github.com/pytorch/pytorch/pull/156703#issuecomment-3152337518))
This commit is contained in:
@ -104,31 +104,6 @@ except ImportError:
|
||||
|
||||
MI300_ARCH = ("gfx942",)
|
||||
|
||||
class ProfilingMode(Enum):
|
||||
LEGACY = 1
|
||||
SIMPLE = 2
|
||||
PROFILING = 3
|
||||
|
||||
# Set by parse_cmd_line_args() if called
|
||||
CI_FUNCTORCH_ROOT = ""
|
||||
CI_PT_ROOT = ""
|
||||
CI_TEST_PREFIX = ""
|
||||
DISABLED_TESTS_FILE = ""
|
||||
GRAPH_EXECUTOR : Optional[ProfilingMode] = None
|
||||
LOG_SUFFIX = ""
|
||||
PYTEST_SINGLE_TEST = ""
|
||||
REPEAT_COUNT = 0
|
||||
RERUN_DISABLED_TESTS = False
|
||||
RUN_PARALLEL = 0
|
||||
SEED : Optional[int] = None
|
||||
SHOWLOCALS = False
|
||||
SLOW_TESTS_FILE = ""
|
||||
TEST_BAILOUTS = False
|
||||
TEST_DISCOVER = False
|
||||
TEST_IN_SUBPROCESS = False
|
||||
TEST_SAVE_XML = ""
|
||||
UNITTEST_ARGS : list[str] = []
|
||||
USE_PYTEST = False
|
||||
|
||||
def freeze_rng_state(*args, **kwargs):
|
||||
return torch.testing._utils.freeze_rng_state(*args, **kwargs)
|
||||
@ -863,6 +838,11 @@ class decorateIf(_TestParametrizer):
|
||||
yield (test_wrapper, test_name, {}, decorator_fn)
|
||||
|
||||
|
||||
class ProfilingMode(Enum):
|
||||
LEGACY = 1
|
||||
SIMPLE = 2
|
||||
PROFILING = 3
|
||||
|
||||
def cppProfilingFlagsToProfilingMode():
|
||||
old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
|
||||
old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
|
||||
@ -881,7 +861,6 @@ def cppProfilingFlagsToProfilingMode():
|
||||
def enable_profiling_mode_for_profiling_tests():
|
||||
old_prof_exec_state = False
|
||||
old_prof_mode_state = False
|
||||
assert GRAPH_EXECUTOR
|
||||
if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
|
||||
old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
|
||||
old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
|
||||
@ -916,7 +895,6 @@ meth_call = torch._C.ScriptMethod.__call__
|
||||
def prof_callable(callable, *args, **kwargs):
|
||||
if 'profile_and_replay' in kwargs:
|
||||
del kwargs['profile_and_replay']
|
||||
assert GRAPH_EXECUTOR
|
||||
if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
|
||||
with enable_profiling_mode_for_profiling_tests():
|
||||
callable(*args, **kwargs)
|
||||
@ -946,93 +924,72 @@ def _get_test_report_path():
|
||||
test_source = override if override is not None else 'python-unittest'
|
||||
return os.path.join('test-reports', test_source)
|
||||
|
||||
def parse_cmd_line_args():
|
||||
global CI_FUNCTORCH_ROOT
|
||||
global CI_PT_ROOT
|
||||
global CI_TEST_PREFIX
|
||||
global DISABLED_TESTS_FILE
|
||||
global GRAPH_EXECUTOR
|
||||
global LOG_SUFFIX
|
||||
global PYTEST_SINGLE_TEST
|
||||
global REPEAT_COUNT
|
||||
global RERUN_DISABLED_TESTS
|
||||
global RUN_PARALLEL
|
||||
global SEED
|
||||
global SHOWLOCALS
|
||||
global SLOW_TESTS_FILE
|
||||
global TEST_BAILOUTS
|
||||
global TEST_DISCOVER
|
||||
global TEST_IN_SUBPROCESS
|
||||
global TEST_SAVE_XML
|
||||
global UNITTEST_ARGS
|
||||
global USE_PYTEST
|
||||
|
||||
is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "")
|
||||
parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False)
|
||||
parser.add_argument('--subprocess', action='store_true',
|
||||
help='whether to run each test in a subprocess')
|
||||
parser.add_argument('--seed', type=int, default=1234)
|
||||
parser.add_argument('--accept', action='store_true')
|
||||
parser.add_argument('--jit-executor', '--jit_executor', type=str)
|
||||
parser.add_argument('--repeat', type=int, default=1)
|
||||
parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
|
||||
parser.add_argument('--use-pytest', action='store_true')
|
||||
parser.add_argument('--save-xml', nargs='?', type=str,
|
||||
const=_get_test_report_path(),
|
||||
default=_get_test_report_path() if IS_CI else None)
|
||||
parser.add_argument('--discover-tests', action='store_true')
|
||||
parser.add_argument('--log-suffix', type=str, default="")
|
||||
parser.add_argument('--run-parallel', type=int, default=1)
|
||||
parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE)
|
||||
parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE)
|
||||
parser.add_argument('--rerun-disabled-tests', action='store_true')
|
||||
parser.add_argument('--pytest-single-test', type=str, nargs=1)
|
||||
parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False)
|
||||
is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "")
|
||||
parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False)
|
||||
parser.add_argument('--subprocess', action='store_true',
|
||||
help='whether to run each test in a subprocess')
|
||||
parser.add_argument('--seed', type=int, default=1234)
|
||||
parser.add_argument('--accept', action='store_true')
|
||||
parser.add_argument('--jit-executor', '--jit_executor', type=str)
|
||||
parser.add_argument('--repeat', type=int, default=1)
|
||||
parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
|
||||
parser.add_argument('--use-pytest', action='store_true')
|
||||
parser.add_argument('--save-xml', nargs='?', type=str,
|
||||
const=_get_test_report_path(),
|
||||
default=_get_test_report_path() if IS_CI else None)
|
||||
parser.add_argument('--discover-tests', action='store_true')
|
||||
parser.add_argument('--log-suffix', type=str, default="")
|
||||
parser.add_argument('--run-parallel', type=int, default=1)
|
||||
parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE)
|
||||
parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE)
|
||||
parser.add_argument('--rerun-disabled-tests', action='store_true')
|
||||
parser.add_argument('--pytest-single-test', type=str, nargs=1)
|
||||
parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False)
|
||||
|
||||
# Only run when -h or --help flag is active to display both unittest and parser help messages.
|
||||
def run_unittest_help(argv):
|
||||
unittest.main(argv=argv)
|
||||
def run_unittest_help(argv):
|
||||
unittest.main(argv=argv)
|
||||
|
||||
if '-h' in sys.argv or '--help' in sys.argv:
|
||||
help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,))
|
||||
help_thread.start()
|
||||
help_thread.join()
|
||||
if '-h' in sys.argv or '--help' in sys.argv:
|
||||
help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,))
|
||||
help_thread.start()
|
||||
help_thread.join()
|
||||
|
||||
args, remaining = parser.parse_known_args()
|
||||
if args.jit_executor == 'legacy':
|
||||
GRAPH_EXECUTOR = ProfilingMode.LEGACY
|
||||
elif args.jit_executor == 'profiling':
|
||||
GRAPH_EXECUTOR = ProfilingMode.PROFILING
|
||||
elif args.jit_executor == 'simple':
|
||||
GRAPH_EXECUTOR = ProfilingMode.SIMPLE
|
||||
else:
|
||||
# infer flags based on the default settings
|
||||
GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode()
|
||||
args, remaining = parser.parse_known_args()
|
||||
if args.jit_executor == 'legacy':
|
||||
GRAPH_EXECUTOR = ProfilingMode.LEGACY
|
||||
elif args.jit_executor == 'profiling':
|
||||
GRAPH_EXECUTOR = ProfilingMode.PROFILING
|
||||
elif args.jit_executor == 'simple':
|
||||
GRAPH_EXECUTOR = ProfilingMode.SIMPLE
|
||||
else:
|
||||
# infer flags based on the default settings
|
||||
GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode()
|
||||
|
||||
RERUN_DISABLED_TESTS = args.rerun_disabled_tests
|
||||
RERUN_DISABLED_TESTS = args.rerun_disabled_tests
|
||||
|
||||
SLOW_TESTS_FILE = args.import_slow_tests
|
||||
DISABLED_TESTS_FILE = args.import_disabled_tests
|
||||
LOG_SUFFIX = args.log_suffix
|
||||
RUN_PARALLEL = args.run_parallel
|
||||
TEST_BAILOUTS = args.test_bailouts
|
||||
USE_PYTEST = args.use_pytest
|
||||
PYTEST_SINGLE_TEST = args.pytest_single_test
|
||||
TEST_DISCOVER = args.discover_tests
|
||||
TEST_IN_SUBPROCESS = args.subprocess
|
||||
TEST_SAVE_XML = args.save_xml
|
||||
REPEAT_COUNT = args.repeat
|
||||
SEED = args.seed
|
||||
SHOWLOCALS = args.showlocals
|
||||
if not getattr(expecttest, "ACCEPT", False):
|
||||
expecttest.ACCEPT = args.accept
|
||||
UNITTEST_ARGS = [sys.argv[0]] + remaining
|
||||
torch.manual_seed(SEED)
|
||||
SLOW_TESTS_FILE = args.import_slow_tests
|
||||
DISABLED_TESTS_FILE = args.import_disabled_tests
|
||||
LOG_SUFFIX = args.log_suffix
|
||||
RUN_PARALLEL = args.run_parallel
|
||||
TEST_BAILOUTS = args.test_bailouts
|
||||
USE_PYTEST = args.use_pytest
|
||||
PYTEST_SINGLE_TEST = args.pytest_single_test
|
||||
TEST_DISCOVER = args.discover_tests
|
||||
TEST_IN_SUBPROCESS = args.subprocess
|
||||
TEST_SAVE_XML = args.save_xml
|
||||
REPEAT_COUNT = args.repeat
|
||||
SEED = args.seed
|
||||
SHOWLOCALS = args.showlocals
|
||||
if not getattr(expecttest, "ACCEPT", False):
|
||||
expecttest.ACCEPT = args.accept
|
||||
UNITTEST_ARGS = [sys.argv[0]] + remaining
|
||||
torch.manual_seed(SEED)
|
||||
|
||||
# CI Prefix path used only on CI environment
|
||||
CI_TEST_PREFIX = str(Path(os.getcwd()))
|
||||
CI_PT_ROOT = str(Path(os.getcwd()).parent)
|
||||
CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))
|
||||
CI_TEST_PREFIX = str(Path(os.getcwd()))
|
||||
CI_PT_ROOT = str(Path(os.getcwd()).parent)
|
||||
CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))
|
||||
|
||||
def wait_for_process(p, timeout=None):
|
||||
try:
|
||||
@ -1181,9 +1138,7 @@ def lint_test_case_extension(suite):
|
||||
return succeed
|
||||
|
||||
|
||||
def get_report_path(argv=None, pytest=False):
|
||||
if argv is None:
|
||||
argv = UNITTEST_ARGS
|
||||
def get_report_path(argv=UNITTEST_ARGS, pytest=False):
|
||||
test_filename = sanitize_test_filename(argv[0])
|
||||
test_report_path = TEST_SAVE_XML + LOG_SUFFIX
|
||||
test_report_path = os.path.join(test_report_path, test_filename)
|
||||
@ -1234,11 +1189,7 @@ def get_pytest_test_cases(argv: list[str]) -> list[str]:
|
||||
return test_collector_plugin.tests
|
||||
|
||||
|
||||
def run_tests(argv=None):
|
||||
parse_cmd_line_args()
|
||||
if argv is None:
|
||||
argv = UNITTEST_ARGS
|
||||
|
||||
def run_tests(argv=UNITTEST_ARGS):
|
||||
# import test files.
|
||||
if SLOW_TESTS_FILE:
|
||||
if os.path.exists(SLOW_TESTS_FILE):
|
||||
@ -1804,7 +1755,6 @@ def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT exe
|
||||
if not isinstance(fn, type):
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
assert GRAPH_EXECUTOR
|
||||
if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
|
||||
raise unittest.SkipTest(msg)
|
||||
else:
|
||||
@ -2435,19 +2385,7 @@ def get_function_arglist(func):
|
||||
return inspect.getfullargspec(func).args
|
||||
|
||||
|
||||
def set_rng_seed(seed=None):
|
||||
if seed is None:
|
||||
if SEED is not None:
|
||||
seed = SEED
|
||||
else:
|
||||
# Can't assert here: this function is called by TestCase.setUp() and some out of tree tests inherit from that class.
|
||||
# So just print a warning and hardcode the seed.
|
||||
seed = 1234
|
||||
msg = ("set_rng_seed() was called without providing a seed and the command line "
|
||||
f"arguments haven't been parsed so the seed will be set to {seed}. "
|
||||
"To remove this warning make sure your test is run via run_tests() or "
|
||||
"parse_cmd_line_args() is called before set_rng_seed() is called.")
|
||||
warnings.warn(msg)
|
||||
def set_rng_seed(seed):
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
if TEST_NUMPY:
|
||||
@ -3470,7 +3408,7 @@ class TestCase(expecttest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
check_if_enable(self)
|
||||
set_rng_seed()
|
||||
set_rng_seed(SEED)
|
||||
|
||||
# Save global check sparse tensor invariants state that can be
|
||||
# restored from tearDown:
|
||||
|
Reference in New Issue
Block a user