mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Stop parsing command line arguments every time common_utils is imported. (#156703)
Last PR in the series to re-submit https://github.com/pytorch/pytorch/pull/134592 as smaller PRs: https://github.com/pytorch/pytorch/pull/154612 https://github.com/pytorch/pytorch/pull/154628 https://github.com/pytorch/pytorch/pull/154715 https://github.com/pytorch/pytorch/pull/154716 https://github.com/pytorch/pytorch/pull/154725 https://github.com/pytorch/pytorch/pull/154728 Pull Request resolved: https://github.com/pytorch/pytorch/pull/156703 Approved by: https://github.com/clee2000
This commit is contained in:
committed by
PyTorch MergeBot
parent
c6329524d8
commit
ac7b4e7fe4
@ -21,6 +21,16 @@ from _pytest.terminal import _get_raw_skip_reason
|
||||
from pytest_shard_custom import pytest_addoptions as shard_addoptions, PytestShardPlugin
|
||||
|
||||
|
||||
try:
|
||||
from torch.testing._internal.common_utils import parse_cmd_line_args
|
||||
except ImportError:
|
||||
# Temporary workaround needed until parse_cmd_line_args makes it into a nightlye because
|
||||
# main / PR's tests are sometimes run against the previous day's nightly which won't
|
||||
# have this function.
|
||||
def parse_cmd_line_args():
|
||||
pass
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _pytest._code.code import ReprFileLocation
|
||||
|
||||
@ -83,6 +93,7 @@ def pytest_addoption(parser: Parser) -> None:
|
||||
|
||||
|
||||
def pytest_configure(config: Config) -> None:
|
||||
parse_cmd_line_args()
|
||||
xmlpath = config.option.xmlpath_reruns
|
||||
# Prevent opening xmllog on worker nodes (xdist).
|
||||
if xmlpath and not hasattr(config, "workerinput"):
|
||||
|
@ -27,6 +27,9 @@ from torch.testing._internal.jit_utils import (
|
||||
)
|
||||
|
||||
|
||||
assert GRAPH_EXECUTOR is not None
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients"
|
||||
)
|
||||
|
@ -3,6 +3,13 @@
|
||||
|
||||
import torch
|
||||
|
||||
if __name__ == '__main__':
|
||||
from torch.testing._internal.common_utils import parse_cmd_line_args
|
||||
|
||||
# The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
|
||||
# before instantiating tests.
|
||||
parse_cmd_line_args()
|
||||
|
||||
# This is how we include tests located in test/jit/...
|
||||
# They are included here so that they are invoked when you call `test_jit.py`,
|
||||
# do not run these test files directly.
|
||||
@ -97,7 +104,7 @@ import torch.nn.functional as F
|
||||
from torch.testing._internal import jit_utils
|
||||
from torch.testing._internal.common_jit import check_against_reference
|
||||
from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, \
|
||||
suppress_warnings, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, \
|
||||
GRAPH_EXECUTOR, suppress_warnings, IS_SANDCASTLE, ProfilingMode, \
|
||||
TestCase, freeze_rng_state, slowTest, TemporaryFileName, \
|
||||
enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \
|
||||
skipIfCrossRef, skipIfTorchDynamo
|
||||
@ -158,6 +165,7 @@ def doAutodiffCheck(testname):
|
||||
if "test_t_" in testname or testname == "test_t":
|
||||
return False
|
||||
|
||||
assert GRAPH_EXECUTOR
|
||||
if GRAPH_EXECUTOR == ProfilingMode.SIMPLE:
|
||||
return False
|
||||
|
||||
@ -201,6 +209,7 @@ def doAutodiffCheck(testname):
|
||||
return testname not in test_exceptions
|
||||
|
||||
|
||||
assert GRAPH_EXECUTOR
|
||||
# TODO: enable TE in PE when all tests are fixed
|
||||
torch._C._jit_set_texpr_fuser_enabled(GRAPH_EXECUTOR == ProfilingMode.PROFILING)
|
||||
torch._C._jit_set_profiling_executor(GRAPH_EXECUTOR != ProfilingMode.LEGACY)
|
||||
|
@ -5,12 +5,17 @@ from torch.cuda.amp import autocast
|
||||
from typing import Optional
|
||||
|
||||
import unittest
|
||||
from test_jit import JitTestCase
|
||||
from torch.testing._internal.common_cuda import TEST_CUDA
|
||||
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
|
||||
from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests, skipIfTorchDynamo
|
||||
from torch.testing import FileCheck
|
||||
from jit.test_models import MnistNet
|
||||
|
||||
if __name__ == '__main__':
|
||||
# The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
|
||||
# before instantiating tests.
|
||||
parse_cmd_line_args()
|
||||
|
||||
from test_jit import JitTestCase
|
||||
TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()
|
||||
|
||||
@skipIfTorchDynamo("Not a TorchDynamo suitable test")
|
||||
|
@ -9,6 +9,13 @@ import torch.nn.functional as F
|
||||
from torch.testing import FileCheck
|
||||
from unittest import skipIf
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch.testing._internal.common_utils import parse_cmd_line_args
|
||||
|
||||
# The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
|
||||
# before instantiating tests.
|
||||
parse_cmd_line_args()
|
||||
|
||||
from torch.testing._internal.common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \
|
||||
enable_profiling_mode_for_profiling_tests, IS_WINDOWS, TemporaryDirectoryName, shell
|
||||
from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, _inline_everything, \
|
||||
|
@ -2,6 +2,14 @@
|
||||
|
||||
import sys
|
||||
sys.argv.append("--jit-executor=legacy")
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch.testing._internal.common_utils import parse_cmd_line_args
|
||||
|
||||
# The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
|
||||
# before instantiating tests.
|
||||
parse_cmd_line_args()
|
||||
|
||||
from test_jit_fuser import * # noqa: F403
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -22,6 +22,13 @@ from torch.testing import FileCheck
|
||||
torch._C._jit_set_profiling_executor(True)
|
||||
torch._C._get_graph_executor_optimize(True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch.testing._internal.common_utils import parse_cmd_line_args
|
||||
|
||||
# The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
|
||||
# before instantiating tests.
|
||||
parse_cmd_line_args()
|
||||
|
||||
from itertools import combinations, permutations, product
|
||||
from textwrap import dedent
|
||||
|
||||
|
@ -2,7 +2,14 @@
|
||||
|
||||
import sys
|
||||
sys.argv.append("--jit-executor=legacy")
|
||||
from test_jit import * # noqa: F403
|
||||
from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests
|
||||
|
||||
if __name__ == '__main__':
|
||||
# The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
|
||||
# before instantiating tests.
|
||||
parse_cmd_line_args()
|
||||
|
||||
from test_jit import * # noqa: F403, F401
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
@ -33,6 +33,7 @@ import torch.nn as nn
|
||||
from torch._C._autograd import DeviceType
|
||||
from torch._C._distributed_c10d import _SymmetricMemory
|
||||
from torch._logging._internal import trace_log
|
||||
from torch.testing._internal import common_utils
|
||||
from torch.testing._internal.common_utils import (
|
||||
FILE_SCHEMA,
|
||||
find_free_port,
|
||||
@ -772,7 +773,12 @@ class MultiProcessTestCase(TestCase):
|
||||
process = proc(
|
||||
target=self.__class__._run,
|
||||
name="process " + str(rank),
|
||||
args=(rank, self._current_test_name(), self.file_name, child_conn),
|
||||
args=(
|
||||
rank,
|
||||
self._current_test_name(),
|
||||
self.file_name,
|
||||
child_conn,
|
||||
),
|
||||
kwargs={
|
||||
"fake_pg": getattr(self, "fake_pg", False),
|
||||
},
|
||||
@ -849,6 +855,7 @@ class MultiProcessTestCase(TestCase):
|
||||
torch._C._set_print_stack_traces_on_fatal_signal(True)
|
||||
# Show full C++ stacktraces when a Python error originating from C++ is raised.
|
||||
os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"
|
||||
common_utils.set_rng_seed()
|
||||
|
||||
# self.id() == e.g. '__main__.TestDistributed.test_get_rank'
|
||||
# We're retrieving a corresponding test and executing it.
|
||||
@ -1670,6 +1677,10 @@ class MultiProcContinuousTest(TestCase):
|
||||
self.rank = cls.rank
|
||||
self.world_size = cls.world_size
|
||||
test_fn = getattr(self, test_name)
|
||||
|
||||
# Ensure all the ranks use the same seed.
|
||||
common_utils.set_rng_seed()
|
||||
|
||||
# Run the test function
|
||||
test_fn(**kwargs)
|
||||
|
||||
|
@ -58,6 +58,7 @@ from torch.testing._internal.common_distributed import (
|
||||
from torch.testing._internal.common_utils import (
|
||||
FILE_SCHEMA,
|
||||
get_cycles_per_ms,
|
||||
set_rng_seed,
|
||||
TEST_CUDA,
|
||||
TEST_HPU,
|
||||
TEST_XPU,
|
||||
@ -1228,6 +1229,7 @@ class FSDPTest(MultiProcessTestCase):
|
||||
dist.barrier(device_ids=device_ids)
|
||||
|
||||
torch._dynamo.reset()
|
||||
set_rng_seed()
|
||||
self.run_test(test_name, pipe)
|
||||
torch._dynamo.reset()
|
||||
|
||||
|
@ -15,6 +15,7 @@ import torch.cuda
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import _reduction as _Reduction
|
||||
from torch.testing._internal import common_utils
|
||||
from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
|
||||
gradcheck, gradgradcheck, set_default_dtype, skipIfTorchDynamo, TEST_WITH_ROCM
|
||||
from torch.testing._internal.common_cuda import TEST_CUDA, SM90OrLater
|
||||
@ -1078,6 +1079,7 @@ def single_batch_reference_fn(input, parameters, module):
|
||||
|
||||
|
||||
def get_new_module_tests():
|
||||
common_utils.set_rng_seed()
|
||||
new_module_tests = [
|
||||
poissonnllloss_no_reduce_test(),
|
||||
bceloss_no_reduce_test(),
|
||||
|
@ -101,9 +101,35 @@ except ImportError:
|
||||
has_pytest = False
|
||||
|
||||
|
||||
SEED = 1234
|
||||
MI300_ARCH = ("gfx942",)
|
||||
MI200_ARCH = ("gfx90a")
|
||||
|
||||
class ProfilingMode(Enum):
|
||||
LEGACY = 1
|
||||
SIMPLE = 2
|
||||
PROFILING = 3
|
||||
|
||||
# Set by parse_cmd_line_args() if called
|
||||
CI_FUNCTORCH_ROOT = ""
|
||||
CI_PT_ROOT = ""
|
||||
CI_TEST_PREFIX = ""
|
||||
DISABLED_TESTS_FILE = ""
|
||||
GRAPH_EXECUTOR : Optional[ProfilingMode] = None
|
||||
LOG_SUFFIX = ""
|
||||
PYTEST_SINGLE_TEST = ""
|
||||
REPEAT_COUNT = 0
|
||||
RERUN_DISABLED_TESTS = False
|
||||
RUN_PARALLEL = 0
|
||||
SHOWLOCALS = False
|
||||
SLOW_TESTS_FILE = ""
|
||||
TEST_BAILOUTS = False
|
||||
TEST_DISCOVER = False
|
||||
TEST_IN_SUBPROCESS = False
|
||||
TEST_SAVE_XML = ""
|
||||
UNITTEST_ARGS : list[str] = []
|
||||
USE_PYTEST = False
|
||||
|
||||
def freeze_rng_state(*args, **kwargs):
|
||||
return torch.testing._utils.freeze_rng_state(*args, **kwargs)
|
||||
|
||||
@ -838,11 +864,6 @@ class decorateIf(_TestParametrizer):
|
||||
yield (test_wrapper, test_name, {}, decorator_fn)
|
||||
|
||||
|
||||
class ProfilingMode(Enum):
|
||||
LEGACY = 1
|
||||
SIMPLE = 2
|
||||
PROFILING = 3
|
||||
|
||||
def cppProfilingFlagsToProfilingMode():
|
||||
old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
|
||||
old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
|
||||
@ -861,6 +882,7 @@ def cppProfilingFlagsToProfilingMode():
|
||||
def enable_profiling_mode_for_profiling_tests():
|
||||
old_prof_exec_state = False
|
||||
old_prof_mode_state = False
|
||||
assert GRAPH_EXECUTOR
|
||||
if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
|
||||
old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
|
||||
old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
|
||||
@ -895,6 +917,7 @@ meth_call = torch._C.ScriptMethod.__call__
|
||||
def prof_callable(callable, *args, **kwargs):
|
||||
if 'profile_and_replay' in kwargs:
|
||||
del kwargs['profile_and_replay']
|
||||
assert GRAPH_EXECUTOR
|
||||
if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
|
||||
with enable_profiling_mode_for_profiling_tests():
|
||||
callable(*args, **kwargs)
|
||||
@ -924,72 +947,91 @@ def _get_test_report_path():
|
||||
test_source = override if override is not None else 'python-unittest'
|
||||
return os.path.join('test-reports', test_source)
|
||||
|
||||
is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "")
|
||||
parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False)
|
||||
parser.add_argument('--subprocess', action='store_true',
|
||||
help='whether to run each test in a subprocess')
|
||||
parser.add_argument('--seed', type=int, default=1234)
|
||||
parser.add_argument('--accept', action='store_true')
|
||||
parser.add_argument('--jit-executor', '--jit_executor', type=str)
|
||||
parser.add_argument('--repeat', type=int, default=1)
|
||||
parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
|
||||
parser.add_argument('--use-pytest', action='store_true')
|
||||
parser.add_argument('--save-xml', nargs='?', type=str,
|
||||
const=_get_test_report_path(),
|
||||
default=_get_test_report_path() if IS_CI else None)
|
||||
parser.add_argument('--discover-tests', action='store_true')
|
||||
parser.add_argument('--log-suffix', type=str, default="")
|
||||
parser.add_argument('--run-parallel', type=int, default=1)
|
||||
parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE)
|
||||
parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE)
|
||||
parser.add_argument('--rerun-disabled-tests', action='store_true')
|
||||
parser.add_argument('--pytest-single-test', type=str, nargs=1)
|
||||
parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False)
|
||||
def parse_cmd_line_args():
|
||||
global CI_FUNCTORCH_ROOT
|
||||
global CI_PT_ROOT
|
||||
global CI_TEST_PREFIX
|
||||
global DISABLED_TESTS_FILE
|
||||
global GRAPH_EXECUTOR
|
||||
global LOG_SUFFIX
|
||||
global PYTEST_SINGLE_TEST
|
||||
global REPEAT_COUNT
|
||||
global RERUN_DISABLED_TESTS
|
||||
global RUN_PARALLEL
|
||||
global SHOWLOCALS
|
||||
global SLOW_TESTS_FILE
|
||||
global TEST_BAILOUTS
|
||||
global TEST_DISCOVER
|
||||
global TEST_IN_SUBPROCESS
|
||||
global TEST_SAVE_XML
|
||||
global UNITTEST_ARGS
|
||||
global USE_PYTEST
|
||||
|
||||
is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "")
|
||||
parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False)
|
||||
parser.add_argument('--subprocess', action='store_true',
|
||||
help='whether to run each test in a subprocess')
|
||||
parser.add_argument('--accept', action='store_true')
|
||||
parser.add_argument('--jit-executor', '--jit_executor', type=str)
|
||||
parser.add_argument('--repeat', type=int, default=1)
|
||||
parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
|
||||
parser.add_argument('--use-pytest', action='store_true')
|
||||
parser.add_argument('--save-xml', nargs='?', type=str,
|
||||
const=_get_test_report_path(),
|
||||
default=_get_test_report_path() if IS_CI else None)
|
||||
parser.add_argument('--discover-tests', action='store_true')
|
||||
parser.add_argument('--log-suffix', type=str, default="")
|
||||
parser.add_argument('--run-parallel', type=int, default=1)
|
||||
parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE)
|
||||
parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE)
|
||||
parser.add_argument('--rerun-disabled-tests', action='store_true')
|
||||
parser.add_argument('--pytest-single-test', type=str, nargs=1)
|
||||
parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False)
|
||||
|
||||
# Only run when -h or --help flag is active to display both unittest and parser help messages.
|
||||
def run_unittest_help(argv):
|
||||
unittest.main(argv=argv)
|
||||
def run_unittest_help(argv):
|
||||
unittest.main(argv=argv)
|
||||
|
||||
if '-h' in sys.argv or '--help' in sys.argv:
|
||||
help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,))
|
||||
help_thread.start()
|
||||
help_thread.join()
|
||||
if '-h' in sys.argv or '--help' in sys.argv:
|
||||
help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,))
|
||||
help_thread.start()
|
||||
help_thread.join()
|
||||
|
||||
args, remaining = parser.parse_known_args()
|
||||
if args.jit_executor == 'legacy':
|
||||
GRAPH_EXECUTOR = ProfilingMode.LEGACY
|
||||
elif args.jit_executor == 'profiling':
|
||||
GRAPH_EXECUTOR = ProfilingMode.PROFILING
|
||||
elif args.jit_executor == 'simple':
|
||||
GRAPH_EXECUTOR = ProfilingMode.SIMPLE
|
||||
else:
|
||||
# infer flags based on the default settings
|
||||
GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode()
|
||||
args, remaining = parser.parse_known_args()
|
||||
if args.jit_executor == 'legacy':
|
||||
GRAPH_EXECUTOR = ProfilingMode.LEGACY
|
||||
elif args.jit_executor == 'profiling':
|
||||
GRAPH_EXECUTOR = ProfilingMode.PROFILING
|
||||
elif args.jit_executor == 'simple':
|
||||
GRAPH_EXECUTOR = ProfilingMode.SIMPLE
|
||||
else:
|
||||
# infer flags based on the default settings
|
||||
GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode()
|
||||
|
||||
RERUN_DISABLED_TESTS = args.rerun_disabled_tests
|
||||
RERUN_DISABLED_TESTS = args.rerun_disabled_tests
|
||||
|
||||
SLOW_TESTS_FILE = args.import_slow_tests
|
||||
DISABLED_TESTS_FILE = args.import_disabled_tests
|
||||
LOG_SUFFIX = args.log_suffix
|
||||
RUN_PARALLEL = args.run_parallel
|
||||
TEST_BAILOUTS = args.test_bailouts
|
||||
USE_PYTEST = args.use_pytest
|
||||
PYTEST_SINGLE_TEST = args.pytest_single_test
|
||||
TEST_DISCOVER = args.discover_tests
|
||||
TEST_IN_SUBPROCESS = args.subprocess
|
||||
TEST_SAVE_XML = args.save_xml
|
||||
REPEAT_COUNT = args.repeat
|
||||
SEED = args.seed
|
||||
SHOWLOCALS = args.showlocals
|
||||
if not getattr(expecttest, "ACCEPT", False):
|
||||
expecttest.ACCEPT = args.accept
|
||||
UNITTEST_ARGS = [sys.argv[0]] + remaining
|
||||
torch.manual_seed(SEED)
|
||||
SLOW_TESTS_FILE = args.import_slow_tests
|
||||
DISABLED_TESTS_FILE = args.import_disabled_tests
|
||||
LOG_SUFFIX = args.log_suffix
|
||||
RUN_PARALLEL = args.run_parallel
|
||||
TEST_BAILOUTS = args.test_bailouts
|
||||
USE_PYTEST = args.use_pytest
|
||||
PYTEST_SINGLE_TEST = args.pytest_single_test
|
||||
TEST_DISCOVER = args.discover_tests
|
||||
TEST_IN_SUBPROCESS = args.subprocess
|
||||
TEST_SAVE_XML = args.save_xml
|
||||
REPEAT_COUNT = args.repeat
|
||||
SHOWLOCALS = args.showlocals
|
||||
if not getattr(expecttest, "ACCEPT", False):
|
||||
expecttest.ACCEPT = args.accept
|
||||
UNITTEST_ARGS = [sys.argv[0]] + remaining
|
||||
|
||||
set_rng_seed()
|
||||
|
||||
# CI Prefix path used only on CI environment
|
||||
CI_TEST_PREFIX = str(Path(os.getcwd()))
|
||||
CI_PT_ROOT = str(Path(os.getcwd()).parent)
|
||||
CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))
|
||||
CI_TEST_PREFIX = str(Path(os.getcwd()))
|
||||
CI_PT_ROOT = str(Path(os.getcwd()).parent)
|
||||
CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))
|
||||
|
||||
def wait_for_process(p, timeout=None):
|
||||
try:
|
||||
@ -1138,7 +1180,9 @@ def lint_test_case_extension(suite):
|
||||
return succeed
|
||||
|
||||
|
||||
def get_report_path(argv=UNITTEST_ARGS, pytest=False):
|
||||
def get_report_path(argv=None, pytest=False):
|
||||
if argv is None:
|
||||
argv = UNITTEST_ARGS
|
||||
test_filename = sanitize_test_filename(argv[0])
|
||||
test_report_path = TEST_SAVE_XML + LOG_SUFFIX
|
||||
test_report_path = os.path.join(test_report_path, test_filename)
|
||||
@ -1189,7 +1233,11 @@ def get_pytest_test_cases(argv: list[str]) -> list[str]:
|
||||
return test_collector_plugin.tests
|
||||
|
||||
|
||||
def run_tests(argv=UNITTEST_ARGS):
|
||||
def run_tests(argv=None):
|
||||
parse_cmd_line_args()
|
||||
if argv is None:
|
||||
argv = UNITTEST_ARGS
|
||||
|
||||
# import test files.
|
||||
if SLOW_TESTS_FILE:
|
||||
if os.path.exists(SLOW_TESTS_FILE):
|
||||
@ -1759,6 +1807,7 @@ def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT exe
|
||||
if not isinstance(fn, type):
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
assert GRAPH_EXECUTOR
|
||||
if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
|
||||
raise unittest.SkipTest(msg)
|
||||
else:
|
||||
@ -2379,7 +2428,9 @@ def get_function_arglist(func):
|
||||
return inspect.getfullargspec(func).args
|
||||
|
||||
|
||||
def set_rng_seed(seed):
|
||||
def set_rng_seed(seed=None):
|
||||
if seed is None:
|
||||
seed = SEED
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
if TEST_NUMPY:
|
||||
@ -3402,7 +3453,7 @@ class TestCase(expecttest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
check_if_enable(self)
|
||||
set_rng_seed(SEED)
|
||||
set_rng_seed()
|
||||
|
||||
# Save global check sparse tensor invariants state that can be
|
||||
# restored from tearDown:
|
||||
|
@ -137,16 +137,18 @@ class Foo:
|
||||
f = Foo(10)
|
||||
f.bar = 1
|
||||
|
||||
foo_cpu_tensor = Foo(torch.randn(3, 3))
|
||||
|
||||
# Defer instantiation until the seed is set so that randn() returns the same
|
||||
# values in all processes.
|
||||
def create_collectives_object_test_list():
|
||||
return [
|
||||
{"key1": 3, "key2": 4, "key3": {"nested": True}},
|
||||
f,
|
||||
Foo(torch.randn(3, 3)),
|
||||
"foo",
|
||||
[1, 2, True, "string", [4, 5, "nested"]],
|
||||
]
|
||||
|
||||
COLLECTIVES_OBJECT_TEST_LIST = [
|
||||
{"key1": 3, "key2": 4, "key3": {"nested": True}},
|
||||
f,
|
||||
foo_cpu_tensor,
|
||||
"foo",
|
||||
[1, 2, True, "string", [4, 5, "nested"]],
|
||||
]
|
||||
|
||||
# Allowlist of distributed backends where profiling collectives is supported.
|
||||
PROFILING_SUPPORTED_BACKENDS = [
|
||||
@ -396,12 +398,6 @@ class ControlFlowToyModel(nn.Module):
|
||||
return F.relu(self.lin1(x))
|
||||
|
||||
|
||||
DDP_NET = Net()
|
||||
BN_NET = BatchNormNet()
|
||||
BN_NET_NO_AFFINE = BatchNormNet(affine=False)
|
||||
ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99)
|
||||
|
||||
|
||||
def get_timeout(test_id):
|
||||
test_name = test_id.split(".")[-1]
|
||||
if test_name in CUSTOMIZED_TIMEOUT:
|
||||
@ -4293,7 +4289,7 @@ class DistributedTest:
|
||||
# as baseline
|
||||
|
||||
# cpu training setup
|
||||
model = DDP_NET
|
||||
model = Net()
|
||||
|
||||
# single gpu training setup
|
||||
model_gpu = copy.deepcopy(model)
|
||||
@ -4348,7 +4344,7 @@ class DistributedTest:
|
||||
_group, _group_id, rank = self._init_global_test()
|
||||
|
||||
# cpu training setup
|
||||
model_base = DDP_NET
|
||||
model_base = Net()
|
||||
|
||||
# DDP-CPU training setup
|
||||
model_DDP = copy.deepcopy(model_base)
|
||||
@ -5497,7 +5493,7 @@ class DistributedTest:
|
||||
def _test_DistributedDataParallel_with_amp(self, grad_is_view=False):
|
||||
torch.manual_seed(31415)
|
||||
# Creates model and optimizer in default precision
|
||||
model = copy.deepcopy(DDP_NET).cuda()
|
||||
model = Net().cuda()
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.03)
|
||||
|
||||
# Creates a GradScaler once at the beginning of training.
|
||||
@ -5582,7 +5578,7 @@ class DistributedTest:
|
||||
# as baseline
|
||||
|
||||
# cpu training setup
|
||||
model = BN_NET if affine else BN_NET_NO_AFFINE
|
||||
model = BatchNormNet() if affine else BatchNormNet(affine=False)
|
||||
|
||||
# single gpu training setup
|
||||
model_gpu = copy.deepcopy(model)
|
||||
@ -5632,6 +5628,7 @@ class DistributedTest:
|
||||
def _test_post_localSGD_optimizer_parity(self, create_averager, grad_is_view):
|
||||
learning_rate = 0.03
|
||||
|
||||
DDP_NET = Net()
|
||||
net = torch.nn.parallel.DistributedDataParallel(
|
||||
copy.deepcopy(DDP_NET).cuda(),
|
||||
device_ids=[self.rank],
|
||||
@ -5698,7 +5695,7 @@ class DistributedTest:
|
||||
learning_rate = 0.03
|
||||
|
||||
net_using_post_localSGD_opt = torch.nn.parallel.DistributedDataParallel(
|
||||
copy.deepcopy(DDP_NET).cuda(), device_ids=[self.rank]
|
||||
Net().cuda(), device_ids=[self.rank]
|
||||
)
|
||||
|
||||
averager = create_averager()
|
||||
@ -5848,7 +5845,7 @@ class DistributedTest:
|
||||
bs_offset = int(rank * 2)
|
||||
global_bs = int(num_processes * 2)
|
||||
|
||||
model = ONLY_SBN_NET
|
||||
model = nn.SyncBatchNorm(2, momentum=0.99)
|
||||
model_gpu = copy.deepcopy(model).cuda(rank)
|
||||
model_DDP = nn.parallel.DistributedDataParallel(
|
||||
model_gpu, device_ids=[rank]
|
||||
@ -6058,6 +6055,7 @@ class DistributedTest:
|
||||
def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_Running_Value(
|
||||
self,
|
||||
):
|
||||
ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99)
|
||||
_group, _group_id, rank = self._init_global_test()
|
||||
model = nn.parallel.DistributedDataParallel(
|
||||
ONLY_SBN_NET.cuda(rank), device_ids=[rank]
|
||||
@ -6125,7 +6123,7 @@ class DistributedTest:
|
||||
def test_DistributedDataParallel_SyncBatchNorm_half(self):
|
||||
_group, _group_id, rank = self._init_global_test()
|
||||
|
||||
model = copy.deepcopy(BN_NET)
|
||||
model = BatchNormNet()
|
||||
model = model.half()
|
||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
model = nn.parallel.DistributedDataParallel(
|
||||
@ -6141,7 +6139,7 @@ class DistributedTest:
|
||||
|
||||
def _test_ddp_logging_data(self, is_gpu):
|
||||
rank = dist.get_rank()
|
||||
model_DDP = copy.deepcopy(DDP_NET)
|
||||
model_DDP = Net()
|
||||
if is_gpu:
|
||||
model_DDP = nn.parallel.DistributedDataParallel(
|
||||
model_DDP.cuda(rank), device_ids=[rank]
|
||||
@ -6417,7 +6415,7 @@ class DistributedTest:
|
||||
BACKEND == "nccl", "nccl does not support DDP on CPU models"
|
||||
)
|
||||
def test_static_graph_api_cpu(self):
|
||||
model_DDP = nn.parallel.DistributedDataParallel(DDP_NET)
|
||||
model_DDP = nn.parallel.DistributedDataParallel(Net())
|
||||
expected_err = "should be called before training loop starts"
|
||||
with self.assertRaisesRegex(RuntimeError, expected_err):
|
||||
local_bs = 2
|
||||
@ -6650,7 +6648,7 @@ class DistributedTest:
|
||||
def _test_allgather_object(self, subgroup=None):
|
||||
# Only set device for NCCL backend since it must use GPUs.
|
||||
|
||||
gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
|
||||
gather_objects = create_collectives_object_test_list()
|
||||
|
||||
backend = os.environ["BACKEND"]
|
||||
if backend == "nccl":
|
||||
@ -6694,7 +6692,7 @@ class DistributedTest:
|
||||
|
||||
def _test_gather_object(self, pg=None):
|
||||
# Ensure stateful objects can be gathered
|
||||
gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
|
||||
gather_objects = create_collectives_object_test_list()
|
||||
my_rank = dist.get_rank(pg)
|
||||
|
||||
backend = os.environ["BACKEND"]
|
||||
@ -7264,7 +7262,7 @@ class DistributedTest:
|
||||
return x
|
||||
|
||||
torch.cuda.set_device(self.rank)
|
||||
model_bn = BN_NET
|
||||
model_bn = BatchNormNet()
|
||||
model_bn = nn.SyncBatchNorm.convert_sync_batchnorm(
|
||||
copy.deepcopy(model_bn)
|
||||
).cuda(self.rank)
|
||||
@ -7560,7 +7558,7 @@ class DistributedTest:
|
||||
loss.backward()
|
||||
|
||||
def _test_broadcast_object_list(self, group=None):
|
||||
gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
|
||||
gather_objects = create_collectives_object_test_list()
|
||||
|
||||
# Only set device for NCCL backend since it must use GPUs.
|
||||
# Case where rank != GPU device.
|
||||
@ -8284,10 +8282,11 @@ class DistributedTest:
|
||||
@require_backend_is_available({"gloo"})
|
||||
def test_scatter_object_list(self):
|
||||
src_rank = 0
|
||||
collectives_object_test_list = create_collectives_object_test_list()
|
||||
scatter_list = (
|
||||
COLLECTIVES_OBJECT_TEST_LIST
|
||||
collectives_object_test_list
|
||||
if self.rank == src_rank
|
||||
else [None for _ in COLLECTIVES_OBJECT_TEST_LIST]
|
||||
else [None for _ in collectives_object_test_list]
|
||||
)
|
||||
world_size = dist.get_world_size()
|
||||
scatter_list = scatter_list[:world_size]
|
||||
@ -8300,8 +8299,8 @@ class DistributedTest:
|
||||
dist.scatter_object_list(output_obj_list, scatter_list, src=src_rank)
|
||||
self.assertEqual(
|
||||
output_obj_list[0],
|
||||
COLLECTIVES_OBJECT_TEST_LIST[
|
||||
self.rank % len(COLLECTIVES_OBJECT_TEST_LIST)
|
||||
collectives_object_test_list[
|
||||
self.rank % len(collectives_object_test_list)
|
||||
],
|
||||
)
|
||||
# Ensure errors are raised upon incorrect arguments.
|
||||
@ -9987,7 +9986,7 @@ class DistributedTest:
|
||||
"Only Nccl & Gloo backend support DistributedDataParallel",
|
||||
)
|
||||
def test_sync_bn_logged(self):
|
||||
model = BN_NET
|
||||
model = BatchNormNet()
|
||||
rank = self.rank
|
||||
# single gpu training setup
|
||||
model_gpu = model.cuda(rank)
|
||||
|
Reference in New Issue
Block a user