Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "Stop parsing command line arguments every time common_utils is imported. (#156703)"
This reverts commit ac7b4e7fe4d233dcd7f6343d42b4fa3d64bce548. Reverted https://github.com/pytorch/pytorch/pull/156703 on behalf of https://github.com/clee2000 due to failing internally D80206253, see above comment for details ([comment](https://github.com/pytorch/pytorch/pull/156703#issuecomment-3362156908))
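The change being reverted moved argument parsing out of module import: torch.testing._internal.common_utils gained an explicit parse_cmd_line_args() helper (only present on the reverted branch), and each test entry point plus test/conftest.py called it before any tests were instantiated, with a no-op fallback for nightlies that predate the helper. A minimal sketch of that calling pattern, assembled from the hunks below rather than copied verbatim from any single file:

# Minimal sketch, assuming the reverted branch's parse_cmd_line_args() helper;
# not a verbatim copy of any one test file.
try:
    from torch.testing._internal.common_utils import parse_cmd_line_args
except ImportError:
    # Older nightlies predate parse_cmd_line_args(); fall back to a no-op so
    # the module can still be imported against them.
    def parse_cmd_line_args():
        pass

if __name__ == "__main__":
    # GRAPH_EXECUTOR and the other common_utils globals depend on command line
    # flags, so parse them before any tests are instantiated.
    parse_cmd_line_args()

    from torch.testing._internal.common_utils import run_tests
    run_tests()

The hunks that follow undo that pattern and restore parsing (and RNG seeding) at import time.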
@@ -21,16 +21,6 @@ from _pytest.terminal import _get_raw_skip_reason
 from pytest_shard_custom import pytest_addoptions as shard_addoptions, PytestShardPlugin
 
 
-try:
-    from torch.testing._internal.common_utils import parse_cmd_line_args
-except ImportError:
-    # Temporary workaround needed until parse_cmd_line_args makes it into a nightlye because
-    # main / PR's tests are sometimes run against the previous day's nightly which won't
-    # have this function.
-    def parse_cmd_line_args():
-        pass
-
-
 if TYPE_CHECKING:
     from _pytest._code.code import ReprFileLocation
 
@@ -93,7 +83,6 @@ def pytest_addoption(parser: Parser) -> None:
 
 
 def pytest_configure(config: Config) -> None:
-    parse_cmd_line_args()
     xmlpath = config.option.xmlpath_reruns
     # Prevent opening xmllog on worker nodes (xdist).
     if xmlpath and not hasattr(config, "workerinput"):
@@ -27,9 +27,6 @@ from torch.testing._internal.jit_utils import (
 )
 
 
-assert GRAPH_EXECUTOR is not None
-
-
 @unittest.skipIf(
     GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients"
 )
@@ -3,13 +3,6 @@
 
 import torch
 
-if __name__ == '__main__':
-    from torch.testing._internal.common_utils import parse_cmd_line_args
-
-    # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
-    # before instantiating tests.
-    parse_cmd_line_args()
-
 # This is how we include tests located in test/jit/...
 # They are included here so that they are invoked when you call `test_jit.py`,
 # do not run these test files directly.
@@ -104,7 +97,7 @@ import torch.nn.functional as F
 from torch.testing._internal import jit_utils
 from torch.testing._internal.common_jit import check_against_reference
 from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, \
-    GRAPH_EXECUTOR, suppress_warnings, IS_SANDCASTLE, ProfilingMode, \
+    suppress_warnings, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, \
     TestCase, freeze_rng_state, slowTest, TemporaryFileName, \
     enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \
     skipIfCrossRef, skipIfTorchDynamo
@@ -165,7 +158,6 @@ def doAutodiffCheck(testname):
     if "test_t_" in testname or testname == "test_t":
         return False
 
-    assert GRAPH_EXECUTOR
     if GRAPH_EXECUTOR == ProfilingMode.SIMPLE:
         return False
 
@@ -209,7 +201,6 @@ def doAutodiffCheck(testname):
     return testname not in test_exceptions
 
 
-assert GRAPH_EXECUTOR
 # TODO: enable TE in PE when all tests are fixed
 torch._C._jit_set_texpr_fuser_enabled(GRAPH_EXECUTOR == ProfilingMode.PROFILING)
 torch._C._jit_set_profiling_executor(GRAPH_EXECUTOR != ProfilingMode.LEGACY)
@@ -5,17 +5,12 @@ from torch.cuda.amp import autocast
 from typing import Optional
 
 import unittest
+from test_jit import JitTestCase
 from torch.testing._internal.common_cuda import TEST_CUDA
-from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests, skipIfTorchDynamo
+from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
 from torch.testing import FileCheck
 from jit.test_models import MnistNet
 
-if __name__ == '__main__':
-    # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
-    # before instantiating tests.
-    parse_cmd_line_args()
-
-    from test_jit import JitTestCase
 TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()
 
 @skipIfTorchDynamo("Not a TorchDynamo suitable test")
@@ -9,13 +9,6 @@ import torch.nn.functional as F
 from torch.testing import FileCheck
 from unittest import skipIf
 
-if __name__ == "__main__":
-    from torch.testing._internal.common_utils import parse_cmd_line_args
-
-    # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
-    # before instantiating tests.
-    parse_cmd_line_args()
-
 from torch.testing._internal.common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \
     enable_profiling_mode_for_profiling_tests, IS_WINDOWS, TemporaryDirectoryName, shell
 from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, _inline_everything, \
@@ -2,14 +2,6 @@
 
 import sys
 sys.argv.append("--jit-executor=legacy")
-
-if __name__ == "__main__":
-    from torch.testing._internal.common_utils import parse_cmd_line_args
-
-    # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
-    # before instantiating tests.
-    parse_cmd_line_args()
-
 from test_jit_fuser import * # noqa: F403
 
 if __name__ == '__main__':
@@ -22,13 +22,6 @@ from torch.testing import FileCheck
 torch._C._jit_set_profiling_executor(True)
 torch._C._get_graph_executor_optimize(True)
 
-if __name__ == "__main__":
-    from torch.testing._internal.common_utils import parse_cmd_line_args
-
-    # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
-    # before instantiating tests.
-    parse_cmd_line_args()
-
 from itertools import combinations, permutations, product
 from textwrap import dedent
 
@@ -2,14 +2,7 @@
 
 import sys
 sys.argv.append("--jit-executor=legacy")
-from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests
-
-if __name__ == '__main__':
-    # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
-    # before instantiating tests.
-    parse_cmd_line_args()
-
-from test_jit import * # noqa: F403, F401
+from test_jit import * # noqa: F403
 
 if __name__ == '__main__':
     run_tests()
@@ -33,7 +33,6 @@ import torch.nn as nn
 from torch._C._autograd import DeviceType
 from torch._C._distributed_c10d import _SymmetricMemory
 from torch._logging._internal import trace_log
-from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
     FILE_SCHEMA,
     find_free_port,
@@ -773,12 +772,7 @@ class MultiProcessTestCase(TestCase):
         process = proc(
             target=self.__class__._run,
             name="process " + str(rank),
-            args=(
-                rank,
-                self._current_test_name(),
-                self.file_name,
-                child_conn,
-            ),
+            args=(rank, self._current_test_name(), self.file_name, child_conn),
             kwargs={
                 "fake_pg": getattr(self, "fake_pg", False),
             },
@@ -855,7 +849,6 @@ class MultiProcessTestCase(TestCase):
         torch._C._set_print_stack_traces_on_fatal_signal(True)
         # Show full C++ stacktraces when a Python error originating from C++ is raised.
         os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"
-        common_utils.set_rng_seed()
 
         # self.id() == e.g. '__main__.TestDistributed.test_get_rank'
         # We're retrieving a corresponding test and executing it.
@@ -1677,10 +1670,6 @@ class MultiProcContinuousTest(TestCase):
         self.rank = cls.rank
         self.world_size = cls.world_size
         test_fn = getattr(self, test_name)
 
-        # Ensure all the ranks use the same seed.
-        common_utils.set_rng_seed()
-
         # Run the test function
         test_fn(**kwargs)
 
@@ -58,7 +58,6 @@ from torch.testing._internal.common_distributed import (
 from torch.testing._internal.common_utils import (
     FILE_SCHEMA,
     get_cycles_per_ms,
-    set_rng_seed,
     TEST_CUDA,
     TEST_HPU,
     TEST_XPU,
@@ -1229,7 +1228,6 @@ class FSDPTest(MultiProcessTestCase):
             dist.barrier(device_ids=device_ids)
 
         torch._dynamo.reset()
-        set_rng_seed()
         self.run_test(test_name, pipe)
         torch._dynamo.reset()
 
@@ -15,7 +15,6 @@ import torch.cuda
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import _reduction as _Reduction
-from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
     gradcheck, gradgradcheck, set_default_dtype, skipIfTorchDynamo, TEST_WITH_ROCM
 from torch.testing._internal.common_cuda import TEST_CUDA, SM90OrLater
@@ -1079,7 +1078,6 @@ def single_batch_reference_fn(input, parameters, module):
 
 
 def get_new_module_tests():
-    common_utils.set_rng_seed()
     new_module_tests = [
         poissonnllloss_no_reduce_test(),
         bceloss_no_reduce_test(),
@@ -101,35 +101,9 @@ except ImportError:
     has_pytest = False
 
 
-SEED = 1234
 MI300_ARCH = ("gfx942",)
 MI200_ARCH = ("gfx90a")
 
-class ProfilingMode(Enum):
-    LEGACY = 1
-    SIMPLE = 2
-    PROFILING = 3
-
-# Set by parse_cmd_line_args() if called
-CI_FUNCTORCH_ROOT = ""
-CI_PT_ROOT = ""
-CI_TEST_PREFIX = ""
-DISABLED_TESTS_FILE = ""
-GRAPH_EXECUTOR : Optional[ProfilingMode] = None
-LOG_SUFFIX = ""
-PYTEST_SINGLE_TEST = ""
-REPEAT_COUNT = 0
-RERUN_DISABLED_TESTS = False
-RUN_PARALLEL = 0
-SHOWLOCALS = False
-SLOW_TESTS_FILE = ""
-TEST_BAILOUTS = False
-TEST_DISCOVER = False
-TEST_IN_SUBPROCESS = False
-TEST_SAVE_XML = ""
-UNITTEST_ARGS : list[str] = []
-USE_PYTEST = False
-
 def freeze_rng_state(*args, **kwargs):
     return torch.testing._utils.freeze_rng_state(*args, **kwargs)
 
@@ -864,6 +838,11 @@ class decorateIf(_TestParametrizer):
             yield (test_wrapper, test_name, {}, decorator_fn)
 
 
+class ProfilingMode(Enum):
+    LEGACY = 1
+    SIMPLE = 2
+    PROFILING = 3
+
 def cppProfilingFlagsToProfilingMode():
     old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
     old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
@@ -882,7 +861,6 @@ def cppProfilingFlagsToProfilingMode():
 def enable_profiling_mode_for_profiling_tests():
     old_prof_exec_state = False
     old_prof_mode_state = False
-    assert GRAPH_EXECUTOR
     if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
         old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
         old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
@@ -917,7 +895,6 @@ meth_call = torch._C.ScriptMethod.__call__
 def prof_callable(callable, *args, **kwargs):
     if 'profile_and_replay' in kwargs:
         del kwargs['profile_and_replay']
-        assert GRAPH_EXECUTOR
         if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
             with enable_profiling_mode_for_profiling_tests():
                 callable(*args, **kwargs)
@@ -947,91 +924,72 @@ def _get_test_report_path():
     test_source = override if override is not None else 'python-unittest'
     return os.path.join('test-reports', test_source)
 
-def parse_cmd_line_args():
-    global CI_FUNCTORCH_ROOT
-    global CI_PT_ROOT
-    global CI_TEST_PREFIX
-    global DISABLED_TESTS_FILE
-    global GRAPH_EXECUTOR
-    global LOG_SUFFIX
-    global PYTEST_SINGLE_TEST
-    global REPEAT_COUNT
-    global RERUN_DISABLED_TESTS
-    global RUN_PARALLEL
-    global SHOWLOCALS
-    global SLOW_TESTS_FILE
-    global TEST_BAILOUTS
-    global TEST_DISCOVER
-    global TEST_IN_SUBPROCESS
-    global TEST_SAVE_XML
-    global UNITTEST_ARGS
-    global USE_PYTEST
-
-    is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "")
-    parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False)
-    parser.add_argument('--subprocess', action='store_true',
-                        help='whether to run each test in a subprocess')
-    parser.add_argument('--accept', action='store_true')
-    parser.add_argument('--jit-executor', '--jit_executor', type=str)
-    parser.add_argument('--repeat', type=int, default=1)
-    parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
-    parser.add_argument('--use-pytest', action='store_true')
-    parser.add_argument('--save-xml', nargs='?', type=str,
-                        const=_get_test_report_path(),
-                        default=_get_test_report_path() if IS_CI else None)
-    parser.add_argument('--discover-tests', action='store_true')
-    parser.add_argument('--log-suffix', type=str, default="")
-    parser.add_argument('--run-parallel', type=int, default=1)
-    parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE)
-    parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE)
-    parser.add_argument('--rerun-disabled-tests', action='store_true')
-    parser.add_argument('--pytest-single-test', type=str, nargs=1)
-    parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False)
+is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "")
+parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False)
+parser.add_argument('--subprocess', action='store_true',
+                    help='whether to run each test in a subprocess')
+parser.add_argument('--seed', type=int, default=1234)
+parser.add_argument('--accept', action='store_true')
+parser.add_argument('--jit-executor', '--jit_executor', type=str)
+parser.add_argument('--repeat', type=int, default=1)
+parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
+parser.add_argument('--use-pytest', action='store_true')
+parser.add_argument('--save-xml', nargs='?', type=str,
+                    const=_get_test_report_path(),
+                    default=_get_test_report_path() if IS_CI else None)
+parser.add_argument('--discover-tests', action='store_true')
+parser.add_argument('--log-suffix', type=str, default="")
+parser.add_argument('--run-parallel', type=int, default=1)
+parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE)
+parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE)
+parser.add_argument('--rerun-disabled-tests', action='store_true')
+parser.add_argument('--pytest-single-test', type=str, nargs=1)
+parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False)
 
 # Only run when -h or --help flag is active to display both unittest and parser help messages.
 def run_unittest_help(argv):
     unittest.main(argv=argv)
 
 if '-h' in sys.argv or '--help' in sys.argv:
     help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,))
     help_thread.start()
     help_thread.join()
 
 args, remaining = parser.parse_known_args()
 if args.jit_executor == 'legacy':
     GRAPH_EXECUTOR = ProfilingMode.LEGACY
 elif args.jit_executor == 'profiling':
     GRAPH_EXECUTOR = ProfilingMode.PROFILING
 elif args.jit_executor == 'simple':
     GRAPH_EXECUTOR = ProfilingMode.SIMPLE
 else:
     # infer flags based on the default settings
     GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode()
 
 RERUN_DISABLED_TESTS = args.rerun_disabled_tests
 
 SLOW_TESTS_FILE = args.import_slow_tests
 DISABLED_TESTS_FILE = args.import_disabled_tests
 LOG_SUFFIX = args.log_suffix
 RUN_PARALLEL = args.run_parallel
 TEST_BAILOUTS = args.test_bailouts
 USE_PYTEST = args.use_pytest
 PYTEST_SINGLE_TEST = args.pytest_single_test
 TEST_DISCOVER = args.discover_tests
 TEST_IN_SUBPROCESS = args.subprocess
 TEST_SAVE_XML = args.save_xml
 REPEAT_COUNT = args.repeat
-SHOWLOCALS = args.showlocals
-if not getattr(expecttest, "ACCEPT", False):
-    expecttest.ACCEPT = args.accept
-UNITTEST_ARGS = [sys.argv[0]] + remaining
-set_rng_seed()
+SEED = args.seed
+SHOWLOCALS = args.showlocals
+if not getattr(expecttest, "ACCEPT", False):
+    expecttest.ACCEPT = args.accept
+UNITTEST_ARGS = [sys.argv[0]] + remaining
+torch.manual_seed(SEED)
 
 # CI Prefix path used only on CI environment
 CI_TEST_PREFIX = str(Path(os.getcwd()))
 CI_PT_ROOT = str(Path(os.getcwd()).parent)
 CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))
 
 def wait_for_process(p, timeout=None):
     try:
@@ -1180,9 +1138,7 @@ def lint_test_case_extension(suite):
     return succeed
 
 
-def get_report_path(argv=None, pytest=False):
-    if argv is None:
-        argv = UNITTEST_ARGS
+def get_report_path(argv=UNITTEST_ARGS, pytest=False):
     test_filename = sanitize_test_filename(argv[0])
     test_report_path = TEST_SAVE_XML + LOG_SUFFIX
     test_report_path = os.path.join(test_report_path, test_filename)
@@ -1233,11 +1189,7 @@ def get_pytest_test_cases(argv: list[str]) -> list[str]:
     return test_collector_plugin.tests
 
 
-def run_tests(argv=None):
-    parse_cmd_line_args()
-    if argv is None:
-        argv = UNITTEST_ARGS
-
+def run_tests(argv=UNITTEST_ARGS):
     # import test files.
     if SLOW_TESTS_FILE:
         if os.path.exists(SLOW_TESTS_FILE):
@@ -1807,7 +1759,6 @@ def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT exe
     if not isinstance(fn, type):
         @wraps(fn)
         def wrapper(*args, **kwargs):
-            assert GRAPH_EXECUTOR
             if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
                 raise unittest.SkipTest(msg)
             else:
@@ -2428,9 +2379,7 @@ def get_function_arglist(func):
     return inspect.getfullargspec(func).args
 
 
-def set_rng_seed(seed=None):
-    if seed is None:
-        seed = SEED
+def set_rng_seed(seed):
     torch.manual_seed(seed)
     random.seed(seed)
     if TEST_NUMPY:
@@ -3453,7 +3402,7 @@ class TestCase(expecttest.TestCase):
 
     def setUp(self):
         check_if_enable(self)
-        set_rng_seed()
+        set_rng_seed(SEED)
 
         # Save global check sparse tensor invariants state that can be
         # restored from tearDown:
@@ -137,18 +137,16 @@ class Foo:
 f = Foo(10)
 f.bar = 1
 
+foo_cpu_tensor = Foo(torch.randn(3, 3))
 
-# Defer instantiation until the seed is set so that randn() returns the same
-# values in all processes.
-def create_collectives_object_test_list():
-    return [
-        {"key1": 3, "key2": 4, "key3": {"nested": True}},
-        f,
-        Foo(torch.randn(3, 3)),
-        "foo",
-        [1, 2, True, "string", [4, 5, "nested"]],
-    ]
 
+COLLECTIVES_OBJECT_TEST_LIST = [
+    {"key1": 3, "key2": 4, "key3": {"nested": True}},
+    f,
+    foo_cpu_tensor,
+    "foo",
+    [1, 2, True, "string", [4, 5, "nested"]],
+]
 
 # Allowlist of distributed backends where profiling collectives is supported.
 PROFILING_SUPPORTED_BACKENDS = [
@@ -398,6 +396,12 @@ class ControlFlowToyModel(nn.Module):
         return F.relu(self.lin1(x))
 
 
+DDP_NET = Net()
+BN_NET = BatchNormNet()
+BN_NET_NO_AFFINE = BatchNormNet(affine=False)
+ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99)
+
+
 def get_timeout(test_id):
     test_name = test_id.split(".")[-1]
     if test_name in CUSTOMIZED_TIMEOUT:
@@ -4289,7 +4293,7 @@ class DistributedTest:
             # as baseline
 
             # cpu training setup
-            model = Net()
+            model = DDP_NET
 
             # single gpu training setup
             model_gpu = copy.deepcopy(model)
@@ -4344,7 +4348,7 @@ class DistributedTest:
             _group, _group_id, rank = self._init_global_test()
 
             # cpu training setup
-            model_base = Net()
+            model_base = DDP_NET
 
             # DDP-CPU training setup
             model_DDP = copy.deepcopy(model_base)
@@ -5493,7 +5497,7 @@ class DistributedTest:
         def _test_DistributedDataParallel_with_amp(self, grad_is_view=False):
             torch.manual_seed(31415)
             # Creates model and optimizer in default precision
-            model = Net().cuda()
+            model = copy.deepcopy(DDP_NET).cuda()
             optimizer = torch.optim.SGD(model.parameters(), lr=0.03)
 
             # Creates a GradScaler once at the beginning of training.
@@ -5578,7 +5582,7 @@ class DistributedTest:
             # as baseline
 
             # cpu training setup
-            model = BatchNormNet() if affine else BatchNormNet(affine=False)
+            model = BN_NET if affine else BN_NET_NO_AFFINE
 
             # single gpu training setup
             model_gpu = copy.deepcopy(model)
@@ -5628,7 +5632,6 @@ class DistributedTest:
         def _test_post_localSGD_optimizer_parity(self, create_averager, grad_is_view):
             learning_rate = 0.03
 
-            DDP_NET = Net()
             net = torch.nn.parallel.DistributedDataParallel(
                 copy.deepcopy(DDP_NET).cuda(),
                 device_ids=[self.rank],
@@ -5695,7 +5698,7 @@ class DistributedTest:
             learning_rate = 0.03
 
             net_using_post_localSGD_opt = torch.nn.parallel.DistributedDataParallel(
-                Net().cuda(), device_ids=[self.rank]
+                copy.deepcopy(DDP_NET).cuda(), device_ids=[self.rank]
             )
 
             averager = create_averager()
@@ -5845,7 +5848,7 @@ class DistributedTest:
             bs_offset = int(rank * 2)
             global_bs = int(num_processes * 2)
 
-            model = nn.SyncBatchNorm(2, momentum=0.99)
+            model = ONLY_SBN_NET
             model_gpu = copy.deepcopy(model).cuda(rank)
             model_DDP = nn.parallel.DistributedDataParallel(
                 model_gpu, device_ids=[rank]
@@ -6055,7 +6058,6 @@ class DistributedTest:
         def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_Running_Value(
             self,
         ):
-            ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99)
             _group, _group_id, rank = self._init_global_test()
             model = nn.parallel.DistributedDataParallel(
                 ONLY_SBN_NET.cuda(rank), device_ids=[rank]
@@ -6123,7 +6125,7 @@ class DistributedTest:
         def test_DistributedDataParallel_SyncBatchNorm_half(self):
             _group, _group_id, rank = self._init_global_test()
 
-            model = BatchNormNet()
+            model = copy.deepcopy(BN_NET)
             model = model.half()
             model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
             model = nn.parallel.DistributedDataParallel(
@@ -6139,7 +6141,7 @@ class DistributedTest:
 
         def _test_ddp_logging_data(self, is_gpu):
             rank = dist.get_rank()
-            model_DDP = Net()
+            model_DDP = copy.deepcopy(DDP_NET)
             if is_gpu:
                 model_DDP = nn.parallel.DistributedDataParallel(
                     model_DDP.cuda(rank), device_ids=[rank]
@@ -6415,7 +6417,7 @@ class DistributedTest:
             BACKEND == "nccl", "nccl does not support DDP on CPU models"
         )
         def test_static_graph_api_cpu(self):
-            model_DDP = nn.parallel.DistributedDataParallel(Net())
+            model_DDP = nn.parallel.DistributedDataParallel(DDP_NET)
             expected_err = "should be called before training loop starts"
             with self.assertRaisesRegex(RuntimeError, expected_err):
                 local_bs = 2
@@ -6648,7 +6650,7 @@ class DistributedTest:
         def _test_allgather_object(self, subgroup=None):
             # Only set device for NCCL backend since it must use GPUs.
 
-            gather_objects = create_collectives_object_test_list()
+            gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
 
             backend = os.environ["BACKEND"]
             if backend == "nccl":
@@ -6692,7 +6694,7 @@ class DistributedTest:
 
         def _test_gather_object(self, pg=None):
             # Ensure stateful objects can be gathered
-            gather_objects = create_collectives_object_test_list()
+            gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
             my_rank = dist.get_rank(pg)
 
             backend = os.environ["BACKEND"]
@@ -7262,7 +7264,7 @@ class DistributedTest:
                 return x
 
             torch.cuda.set_device(self.rank)
-            model_bn = BatchNormNet()
+            model_bn = BN_NET
             model_bn = nn.SyncBatchNorm.convert_sync_batchnorm(
                 copy.deepcopy(model_bn)
             ).cuda(self.rank)
@@ -7558,7 +7560,7 @@ class DistributedTest:
             loss.backward()
 
         def _test_broadcast_object_list(self, group=None):
-            gather_objects = create_collectives_object_test_list()
+            gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
 
             # Only set device for NCCL backend since it must use GPUs.
             # Case where rank != GPU device.
@@ -8282,11 +8284,10 @@ class DistributedTest:
         @require_backend_is_available({"gloo"})
         def test_scatter_object_list(self):
             src_rank = 0
-            collectives_object_test_list = create_collectives_object_test_list()
             scatter_list = (
-                collectives_object_test_list
+                COLLECTIVES_OBJECT_TEST_LIST
                 if self.rank == src_rank
-                else [None for _ in collectives_object_test_list]
+                else [None for _ in COLLECTIVES_OBJECT_TEST_LIST]
             )
             world_size = dist.get_world_size()
             scatter_list = scatter_list[:world_size]
@@ -8299,8 +8300,8 @@ class DistributedTest:
             dist.scatter_object_list(output_obj_list, scatter_list, src=src_rank)
             self.assertEqual(
                 output_obj_list[0],
-                collectives_object_test_list[
-                    self.rank % len(collectives_object_test_list)
+                COLLECTIVES_OBJECT_TEST_LIST[
+                    self.rank % len(COLLECTIVES_OBJECT_TEST_LIST)
                 ],
             )
             # Ensure errors are raised upon incorrect arguments.
@@ -9986,7 +9987,7 @@ class DistributedTest:
             "Only Nccl & Gloo backend support DistributedDataParallel",
         )
         def test_sync_bn_logged(self):
-            model = BatchNormNet()
+            model = BN_NET
             rank = self.rank
             # single gpu training setup
             model_gpu = model.cuda(rank)