Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 13:44:15 +08:00
Revert "Stop parsing command line arguments every time common_utils is imported. (#156703)"
This reverts commit 310f901a71e53688866b14bb2f2b4c8eef9979b3. Reverted https://github.com/pytorch/pytorch/pull/156703 on behalf of https://github.com/izaitsevfb due to breaking tests internally with `assert common_utils.SEED is not None` ([comment](https://github.com/pytorch/pytorch/pull/156703#issuecomment-3152337518))
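For context on the failure mode: before this revert, MultiProcessTestCase._start_processes asserted that common_utils.SEED had been populated by parse_cmd_line_args() (see the hunks below). The sketch below shows how a runner that drives the test case directly could trip that assertion; the test class and the direct unittest invocation are hypothetical stand-ins, not code from this commit:

    # Assumes the pre-revert (#156703) state of torch.testing._internal.common_utils.
    import unittest

    from torch.testing._internal import common_utils
    from torch.testing._internal.common_distributed import MultiProcessTestCase


    class ExampleMultiProcTest(MultiProcessTestCase):  # hypothetical test case
        @property
        def world_size(self):
            return 2

        def setUp(self):
            super().setUp()
            self._spawn_processes()  # eventually reaches _start_processes()

        def test_noop(self):
            pass


    if __name__ == "__main__":
        # Driving the case directly (unittest, pytest, a custom loader) bypasses
        # run_tests(), so parse_cmd_line_args() never runs, common_utils.SEED stays
        # None, and _start_processes() fails on "assert common_utils.SEED is not None".
        unittest.main()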
@@ -4280,11 +4280,10 @@ class NCCLTraceTestBase(MultiProcessTestCase):
         test_name: str,
         file_name: str,
         parent_pipe,
-        seed: int,
         **kwargs,
     ) -> None:
         cls.parent = parent_conn
-        super()._run(rank, test_name, file_name, parent_pipe, seed)
+        super()._run(rank, test_name, file_name, parent_pipe)

     @property
     def local_device(self):
@@ -27,9 +27,6 @@ from torch.testing._internal.jit_utils import (
 )


-assert GRAPH_EXECUTOR is not None
-
-
 @unittest.skipIf(
     GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients"
 )
@@ -35,11 +35,6 @@ class TestCppApiParity(common.TestCase):
 functional_test_params_map = {}


-if __name__ == "__main__":
-    # The value of the SEED depends on command line arguments so make sure they're parsed
-    # before instantiating tests because some modules as part of get_new_module_tests() will call torch.randn
-    common.parse_cmd_line_args()
-
 expected_test_params_dicts = []

 for test_params_dicts, test_instance_class in [
@@ -1008,13 +1008,6 @@ def filter_supported_tests(t):
     return True


-if __name__ == "__main__":
-    from torch.testing._internal.common_utils import parse_cmd_line_args
-
-    # The value of the SEED depends on command line arguments so make sure they're parsed
-    # before instantiating tests because some modules as part of get_new_module_tests() will call torch.randn
-    parse_cmd_line_args()
-
 # TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests
 # These currently use the legacy nn tests
 supported_tests = [
@@ -3,13 +3,6 @@

 import torch

-if __name__ == '__main__':
-    from torch.testing._internal.common_utils import parse_cmd_line_args
-
-    # The value of GRAPH_EXECUTOR and SEED depend on command line arguments so make sure they're parsed
-    # before instantiating tests.
-    parse_cmd_line_args()
-
 # This is how we include tests located in test/jit/...
 # They are included here so that they are invoked when you call `test_jit.py`,
 # do not run these test files directly.
@@ -104,7 +97,7 @@ import torch.nn.functional as F
 from torch.testing._internal import jit_utils
 from torch.testing._internal.common_jit import check_against_reference
 from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, \
-    GRAPH_EXECUTOR, suppress_warnings, IS_SANDCASTLE, ProfilingMode, \
+    suppress_warnings, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, \
     TestCase, freeze_rng_state, slowTest, TemporaryFileName, \
     enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \
     skipIfCrossRef, skipIfTorchDynamo
@@ -165,7 +158,6 @@ def doAutodiffCheck(testname):
     if "test_t_" in testname or testname == "test_t":
         return False

-    assert GRAPH_EXECUTOR
     if GRAPH_EXECUTOR == ProfilingMode.SIMPLE:
         return False

@@ -209,7 +201,6 @@ def doAutodiffCheck(testname):
     return testname not in test_exceptions


-assert GRAPH_EXECUTOR
 # TODO: enable TE in PE when all tests are fixed
 torch._C._jit_set_texpr_fuser_enabled(GRAPH_EXECUTOR == ProfilingMode.PROFILING)
 torch._C._jit_set_profiling_executor(GRAPH_EXECUTOR != ProfilingMode.LEGACY)
@@ -5,17 +5,12 @@ from torch.cuda.amp import autocast
 from typing import Optional

 import unittest
+from test_jit import JitTestCase
 from torch.testing._internal.common_cuda import TEST_CUDA
-from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests, skipIfTorchDynamo
+from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
 from torch.testing import FileCheck
 from jit.test_models import MnistNet

-if __name__ == '__main__':
-    # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
-    # before instantiating tests.
-    parse_cmd_line_args()
-
-from test_jit import JitTestCase
 TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()

 @skipIfTorchDynamo("Not a TorchDynamo suitable test")
@@ -9,13 +9,6 @@ import torch.nn.functional as F
 from torch.testing import FileCheck
 from unittest import skipIf

-if __name__ == "__main__":
-    from torch.testing._internal.common_utils import parse_cmd_line_args
-
-    # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
-    # before instantiating tests.
-    parse_cmd_line_args()
-
 from torch.testing._internal.common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \
     enable_profiling_mode_for_profiling_tests, IS_WINDOWS, TemporaryDirectoryName, shell
 from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, _inline_everything, \
@@ -2,14 +2,6 @@

 import sys
 sys.argv.append("--jit-executor=legacy")
-
-if __name__ == "__main__":
-    from torch.testing._internal.common_utils import parse_cmd_line_args
-
-    # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
-    # before instantiating tests.
-    parse_cmd_line_args()
-
 from test_jit_fuser import * # noqa: F403

 if __name__ == '__main__':
@@ -22,13 +22,6 @@ from torch.testing import FileCheck
 torch._C._jit_set_profiling_executor(True)
 torch._C._get_graph_executor_optimize(True)

-if __name__ == "__main__":
-    from torch.testing._internal.common_utils import parse_cmd_line_args
-
-    # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
-    # before instantiating tests.
-    parse_cmd_line_args()
-
 from itertools import combinations, permutations, product
 from textwrap import dedent

@@ -2,14 +2,7 @@

 import sys
 sys.argv.append("--jit-executor=legacy")
-from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests
-
-if __name__ == '__main__':
-    # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
-    # before instantiating tests.
-    parse_cmd_line_args()
-
-from test_jit import * # noqa: F403, F401
+from test_jit import * # noqa: F403

 if __name__ == '__main__':
     run_tests()
@@ -7643,13 +7643,6 @@ def add_test(test, decorator=None):
     else:
         add(cuda_test_name, with_tf32_off)

-if __name__ == '__main__':
-    from torch.testing._internal.common_utils import parse_cmd_line_args
-
-    # The value of the SEED depends on command line arguments so make sure they're parsed
-    # before instantiating tests because some modules as part of get_new_module_tests() will call torch.randn
-    parse_cmd_line_args()
-
 for test_params in module_tests + get_new_module_tests():
     # TODO: CUDA is not implemented yet
     if 'constructor' not in test_params:
@@ -32,7 +32,6 @@ import torch.nn as nn
 from torch._C._autograd import DeviceType
 from torch._C._distributed_c10d import _SymmetricMemory
 from torch._logging._internal import trace_log
-from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
     FILE_SCHEMA,
     find_free_port,
@@ -672,7 +671,6 @@ class MultiProcessTestCase(TestCase):
         if methodName != "runTest":
             method_name = methodName
         super().__init__(method_name)
-        self.seed = None
         try:
             fn = getattr(self, method_name)
             setattr(self, method_name, self.join_or_run(fn))
@@ -717,20 +715,13 @@ class MultiProcessTestCase(TestCase):

     def _start_processes(self, proc) -> None:
         self.processes = []
-        assert common_utils.SEED is not None
         for rank in range(int(self.world_size)):
             parent_conn, child_conn = torch.multiprocessing.Pipe()
             process = proc(
                 target=self.__class__._run,
                 name="process " + str(rank),
-                args=(
-                    rank,
-                    self._current_test_name(),
-                    self.file_name,
-                    child_conn,
-                ),
+                args=(rank, self._current_test_name(), self.file_name, child_conn),
                 kwargs={
-                    "seed": common_utils.SEED,
                     "fake_pg": getattr(self, "fake_pg", False),
                 },
             )
@@ -784,12 +775,11 @@ class MultiProcessTestCase(TestCase):

     @classmethod
     def _run(
-        cls, rank: int, test_name: str, file_name: str, parent_pipe, seed: int, **kwargs
+        cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs
     ) -> None:
         self = cls(test_name)
         self.rank = rank
         self.file_name = file_name
-        self.seed = seed
         self.run_test(test_name, parent_pipe)

     def run_test(self, test_name: str, parent_pipe) -> None:
@@ -808,9 +798,6 @@ class MultiProcessTestCase(TestCase):
         # Show full C++ stacktraces when a Python error originating from C++ is raised.
         os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"

-        if self.seed is not None:
-            common_utils.set_rng_seed(self.seed)
-
         # self.id() == e.g. '__main__.TestDistributed.test_get_rank'
         # We're retrieving a corresponding test and executing it.
         try:
@@ -1548,7 +1535,7 @@ class DynamoDistributedMultiProcTestCase(DistributedTestBase):

     @classmethod
     def _run(
-        cls, rank: int, test_name: str, file_name: str, parent_pipe, seed: int, **kwargs
+        cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs
     ) -> None:
         trace_log.addHandler(logging.NullHandler())

@@ -1556,7 +1543,6 @@ class DynamoDistributedMultiProcTestCase(DistributedTestBase):
         self = cls(test_name)
         self.rank = rank
         self.file_name = file_name
-        self.seed = seed
         self.run_test(test_name, parent_pipe)


@@ -57,7 +57,6 @@ from torch.testing._internal.common_distributed import (
 from torch.testing._internal.common_utils import (
     FILE_SCHEMA,
     get_cycles_per_ms,
-    set_rng_seed,
     TEST_CUDA,
     TEST_HPU,
     TEST_XPU,
@@ -1181,7 +1180,7 @@ class FSDPTest(MultiProcessTestCase):
         return run_subtests(self, *args, **kwargs)

     @classmethod
-    def _run(cls, rank, test_name, file_name, pipe, seed, **kwargs): # type: ignore[override]
+    def _run(cls, rank, test_name, file_name, pipe, **kwargs): # type: ignore[override]
         self = cls(test_name)
         self.rank = rank
         self.file_name = file_name
@@ -1227,7 +1226,6 @@ class FSDPTest(MultiProcessTestCase):
             dist.barrier(device_ids=device_ids)

         torch._dynamo.reset()
-        set_rng_seed(seed)
         self.run_test(test_name, pipe)
         torch._dynamo.reset()

@@ -15,7 +15,6 @@ import torch.cuda
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import _reduction as _Reduction
-from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
     gradcheck, gradgradcheck, set_default_dtype, skipIfTorchDynamo, TEST_WITH_ROCM
 from torch.testing._internal.common_cuda import TEST_CUDA, SM90OrLater
@@ -1079,7 +1078,6 @@ def single_batch_reference_fn(input, parameters, module):


 def get_new_module_tests():
-    assert common_utils.SEED is not None, "Make sure the seed is set before calling get_new_module_tests()"
     new_module_tests = [
         poissonnllloss_no_reduce_test(),
         bceloss_no_reduce_test(),
@@ -104,31 +104,6 @@ except ImportError:

 MI300_ARCH = ("gfx942",)

-class ProfilingMode(Enum):
-    LEGACY = 1
-    SIMPLE = 2
-    PROFILING = 3
-
-# Set by parse_cmd_line_args() if called
-CI_FUNCTORCH_ROOT = ""
-CI_PT_ROOT = ""
-CI_TEST_PREFIX = ""
-DISABLED_TESTS_FILE = ""
-GRAPH_EXECUTOR : Optional[ProfilingMode] = None
-LOG_SUFFIX = ""
-PYTEST_SINGLE_TEST = ""
-REPEAT_COUNT = 0
-RERUN_DISABLED_TESTS = False
-RUN_PARALLEL = 0
-SEED : Optional[int] = None
-SHOWLOCALS = False
-SLOW_TESTS_FILE = ""
-TEST_BAILOUTS = False
-TEST_DISCOVER = False
-TEST_IN_SUBPROCESS = False
-TEST_SAVE_XML = ""
-UNITTEST_ARGS : list[str] = []
-USE_PYTEST = False

 def freeze_rng_state(*args, **kwargs):
     return torch.testing._utils.freeze_rng_state(*args, **kwargs)
@@ -863,6 +838,11 @@ class decorateIf(_TestParametrizer):
         yield (test_wrapper, test_name, {}, decorator_fn)


+class ProfilingMode(Enum):
+    LEGACY = 1
+    SIMPLE = 2
+    PROFILING = 3
+
 def cppProfilingFlagsToProfilingMode():
     old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
     old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
@@ -881,7 +861,6 @@ def cppProfilingFlagsToProfilingMode():
 def enable_profiling_mode_for_profiling_tests():
     old_prof_exec_state = False
     old_prof_mode_state = False
-    assert GRAPH_EXECUTOR
     if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
         old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
         old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
@@ -916,7 +895,6 @@ meth_call = torch._C.ScriptMethod.__call__
 def prof_callable(callable, *args, **kwargs):
     if 'profile_and_replay' in kwargs:
         del kwargs['profile_and_replay']
-        assert GRAPH_EXECUTOR
         if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
             with enable_profiling_mode_for_profiling_tests():
                 callable(*args, **kwargs)
@@ -946,93 +924,72 @@ def _get_test_report_path():
     test_source = override if override is not None else 'python-unittest'
     return os.path.join('test-reports', test_source)

-def parse_cmd_line_args():
-    global CI_FUNCTORCH_ROOT
-    global CI_PT_ROOT
-    global CI_TEST_PREFIX
-    global DISABLED_TESTS_FILE
-    global GRAPH_EXECUTOR
-    global LOG_SUFFIX
-    global PYTEST_SINGLE_TEST
-    global REPEAT_COUNT
-    global RERUN_DISABLED_TESTS
-    global RUN_PARALLEL
-    global SEED
-    global SHOWLOCALS
-    global SLOW_TESTS_FILE
-    global TEST_BAILOUTS
-    global TEST_DISCOVER
-    global TEST_IN_SUBPROCESS
-    global TEST_SAVE_XML
-    global UNITTEST_ARGS
-    global USE_PYTEST
-
-    is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "")
-    parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False)
-    parser.add_argument('--subprocess', action='store_true',
-                        help='whether to run each test in a subprocess')
-    parser.add_argument('--seed', type=int, default=1234)
-    parser.add_argument('--accept', action='store_true')
-    parser.add_argument('--jit-executor', '--jit_executor', type=str)
-    parser.add_argument('--repeat', type=int, default=1)
-    parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
-    parser.add_argument('--use-pytest', action='store_true')
-    parser.add_argument('--save-xml', nargs='?', type=str,
-                        const=_get_test_report_path(),
-                        default=_get_test_report_path() if IS_CI else None)
-    parser.add_argument('--discover-tests', action='store_true')
-    parser.add_argument('--log-suffix', type=str, default="")
-    parser.add_argument('--run-parallel', type=int, default=1)
-    parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE)
-    parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE)
-    parser.add_argument('--rerun-disabled-tests', action='store_true')
-    parser.add_argument('--pytest-single-test', type=str, nargs=1)
-    parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False)
+is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "")
+parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False)
+parser.add_argument('--subprocess', action='store_true',
+                    help='whether to run each test in a subprocess')
+parser.add_argument('--seed', type=int, default=1234)
+parser.add_argument('--accept', action='store_true')
+parser.add_argument('--jit-executor', '--jit_executor', type=str)
+parser.add_argument('--repeat', type=int, default=1)
+parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
+parser.add_argument('--use-pytest', action='store_true')
+parser.add_argument('--save-xml', nargs='?', type=str,
+                    const=_get_test_report_path(),
+                    default=_get_test_report_path() if IS_CI else None)
+parser.add_argument('--discover-tests', action='store_true')
+parser.add_argument('--log-suffix', type=str, default="")
+parser.add_argument('--run-parallel', type=int, default=1)
+parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE)
+parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE)
+parser.add_argument('--rerun-disabled-tests', action='store_true')
+parser.add_argument('--pytest-single-test', type=str, nargs=1)
+parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False)

-    # Only run when -h or --help flag is active to display both unittest and parser help messages.
-    def run_unittest_help(argv):
-        unittest.main(argv=argv)
+# Only run when -h or --help flag is active to display both unittest and parser help messages.
+def run_unittest_help(argv):
+    unittest.main(argv=argv)

-    if '-h' in sys.argv or '--help' in sys.argv:
-        help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,))
-        help_thread.start()
-        help_thread.join()
+if '-h' in sys.argv or '--help' in sys.argv:
+    help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,))
+    help_thread.start()
+    help_thread.join()

-    args, remaining = parser.parse_known_args()
-    if args.jit_executor == 'legacy':
-        GRAPH_EXECUTOR = ProfilingMode.LEGACY
-    elif args.jit_executor == 'profiling':
-        GRAPH_EXECUTOR = ProfilingMode.PROFILING
-    elif args.jit_executor == 'simple':
-        GRAPH_EXECUTOR = ProfilingMode.SIMPLE
-    else:
-        # infer flags based on the default settings
-        GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode()
+args, remaining = parser.parse_known_args()
+if args.jit_executor == 'legacy':
+    GRAPH_EXECUTOR = ProfilingMode.LEGACY
+elif args.jit_executor == 'profiling':
+    GRAPH_EXECUTOR = ProfilingMode.PROFILING
+elif args.jit_executor == 'simple':
+    GRAPH_EXECUTOR = ProfilingMode.SIMPLE
+else:
+    # infer flags based on the default settings
+    GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode()

-    RERUN_DISABLED_TESTS = args.rerun_disabled_tests
+RERUN_DISABLED_TESTS = args.rerun_disabled_tests

-    SLOW_TESTS_FILE = args.import_slow_tests
-    DISABLED_TESTS_FILE = args.import_disabled_tests
-    LOG_SUFFIX = args.log_suffix
-    RUN_PARALLEL = args.run_parallel
-    TEST_BAILOUTS = args.test_bailouts
-    USE_PYTEST = args.use_pytest
-    PYTEST_SINGLE_TEST = args.pytest_single_test
-    TEST_DISCOVER = args.discover_tests
-    TEST_IN_SUBPROCESS = args.subprocess
-    TEST_SAVE_XML = args.save_xml
-    REPEAT_COUNT = args.repeat
-    SEED = args.seed
-    SHOWLOCALS = args.showlocals
-    if not getattr(expecttest, "ACCEPT", False):
-        expecttest.ACCEPT = args.accept
-    UNITTEST_ARGS = [sys.argv[0]] + remaining
-    torch.manual_seed(SEED)
+SLOW_TESTS_FILE = args.import_slow_tests
+DISABLED_TESTS_FILE = args.import_disabled_tests
+LOG_SUFFIX = args.log_suffix
+RUN_PARALLEL = args.run_parallel
+TEST_BAILOUTS = args.test_bailouts
+USE_PYTEST = args.use_pytest
+PYTEST_SINGLE_TEST = args.pytest_single_test
+TEST_DISCOVER = args.discover_tests
+TEST_IN_SUBPROCESS = args.subprocess
+TEST_SAVE_XML = args.save_xml
+REPEAT_COUNT = args.repeat
+SEED = args.seed
+SHOWLOCALS = args.showlocals
+if not getattr(expecttest, "ACCEPT", False):
+    expecttest.ACCEPT = args.accept
+UNITTEST_ARGS = [sys.argv[0]] + remaining
+torch.manual_seed(SEED)

-    # CI Prefix path used only on CI environment
-    CI_TEST_PREFIX = str(Path(os.getcwd()))
-    CI_PT_ROOT = str(Path(os.getcwd()).parent)
-    CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))
+# CI Prefix path used only on CI environment
+CI_TEST_PREFIX = str(Path(os.getcwd()))
+CI_PT_ROOT = str(Path(os.getcwd()).parent)
+CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))

 def wait_for_process(p, timeout=None):
     try:
@@ -1181,9 +1138,7 @@ def lint_test_case_extension(suite):
     return succeed


-def get_report_path(argv=None, pytest=False):
-    if argv is None:
-        argv = UNITTEST_ARGS
+def get_report_path(argv=UNITTEST_ARGS, pytest=False):
     test_filename = sanitize_test_filename(argv[0])
     test_report_path = TEST_SAVE_XML + LOG_SUFFIX
     test_report_path = os.path.join(test_report_path, test_filename)
@@ -1234,11 +1189,7 @@ def get_pytest_test_cases(argv: list[str]) -> list[str]:
     return test_collector_plugin.tests


-def run_tests(argv=None):
-    parse_cmd_line_args()
-    if argv is None:
-        argv = UNITTEST_ARGS
-
+def run_tests(argv=UNITTEST_ARGS):
     # import test files.
     if SLOW_TESTS_FILE:
         if os.path.exists(SLOW_TESTS_FILE):
@@ -1804,7 +1755,6 @@ def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT exe
         if not isinstance(fn, type):
             @wraps(fn)
             def wrapper(*args, **kwargs):
-                assert GRAPH_EXECUTOR
                 if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
                     raise unittest.SkipTest(msg)
                 else:
@@ -2435,19 +2385,7 @@ def get_function_arglist(func):
     return inspect.getfullargspec(func).args


-def set_rng_seed(seed=None):
-    if seed is None:
-        if SEED is not None:
-            seed = SEED
-        else:
-            # Can't assert here: this function is called by TestCase.setUp() and some out of tree tests inherit from that class.
-            # So just print a warning and hardcode the seed.
-            seed = 1234
-            msg = ("set_rng_seed() was called without providing a seed and the command line "
-                   f"arguments haven't been parsed so the seed will be set to {seed}. "
-                   "To remove this warning make sure your test is run via run_tests() or "
-                   "parse_cmd_line_args() is called before set_rng_seed() is called.")
-            warnings.warn(msg)
+def set_rng_seed(seed):
     torch.manual_seed(seed)
     random.seed(seed)
     if TEST_NUMPY:
@@ -3470,7 +3408,7 @@ class TestCase(expecttest.TestCase):

     def setUp(self):
         check_if_enable(self)
-        set_rng_seed()
+        set_rng_seed(SEED)

         # Save global check sparse tensor invariants state that can be
         # restored from tearDown:
@@ -137,18 +137,16 @@ class Foo:
 f = Foo(10)
 f.bar = 1

+foo_cpu_tensor = Foo(torch.randn(3, 3))

-# Defer instantiation until the seed is set so that randn() returns the same
-# values in all processes.
-def create_collectives_object_test_list():
-    return [
-        {"key1": 3, "key2": 4, "key3": {"nested": True}},
-        f,
-        Foo(torch.randn(3, 3)),
-        "foo",
-        [1, 2, True, "string", [4, 5, "nested"]],
-    ]

+COLLECTIVES_OBJECT_TEST_LIST = [
+    {"key1": 3, "key2": 4, "key3": {"nested": True}},
+    f,
+    foo_cpu_tensor,
+    "foo",
+    [1, 2, True, "string", [4, 5, "nested"]],
+]

 # Allowlist of distributed backends where profiling collectives is supported.
 PROFILING_SUPPORTED_BACKENDS = [
@@ -398,6 +396,12 @@ class ControlFlowToyModel(nn.Module):
         return F.relu(self.lin1(x))


+DDP_NET = Net()
+BN_NET = BatchNormNet()
+BN_NET_NO_AFFINE = BatchNormNet(affine=False)
+ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99)
+
+
 def get_timeout(test_id):
     test_name = test_id.split(".")[-1]
     if test_name in CUSTOMIZED_TIMEOUT:
@@ -590,13 +594,12 @@ class TestDistBackend(MultiProcessTestCase):
         return False

     @classmethod
-    def _run(cls, rank, test_name, file_name, pipe, seed, **kwargs):
+    def _run(cls, rank, test_name, file_name, pipe, **kwargs):
         if BACKEND == "nccl" and not torch.cuda.is_available():
             sys.exit(TEST_SKIPS["no_cuda"].exit_code)
         self = cls(test_name)
         self.rank = rank
         self.file_name = file_name
-        self.seed = seed

         if torch.cuda.is_available() and torch.cuda.device_count() < int(
             self.world_size
@@ -4284,7 +4287,7 @@ class DistributedTest:
             # as baseline

             # cpu training setup
-            model = Net()
+            model = DDP_NET

             # single gpu training setup
             model_gpu = copy.deepcopy(model)
@@ -4339,7 +4342,7 @@ class DistributedTest:
             _group, _group_id, rank = self._init_global_test()

             # cpu training setup
-            model_base = Net()
+            model_base = DDP_NET

             # DDP-CPU training setup
             model_DDP = copy.deepcopy(model_base)
@@ -5488,7 +5491,7 @@ class DistributedTest:
         def _test_DistributedDataParallel_with_amp(self, grad_is_view=False):
             torch.manual_seed(31415)
             # Creates model and optimizer in default precision
-            model = Net().cuda()
+            model = copy.deepcopy(DDP_NET).cuda()
             optimizer = torch.optim.SGD(model.parameters(), lr=0.03)

             # Creates a GradScaler once at the beginning of training.
@@ -5573,7 +5576,7 @@ class DistributedTest:
             # as baseline

             # cpu training setup
-            model = BatchNormNet() if affine else BatchNormNet(affine=False)
+            model = BN_NET if affine else BN_NET_NO_AFFINE

             # single gpu training setup
             model_gpu = copy.deepcopy(model)
@@ -5623,7 +5626,6 @@ class DistributedTest:
         def _test_post_localSGD_optimizer_parity(self, create_averager, grad_is_view):
             learning_rate = 0.03

-            DDP_NET = Net()
             net = torch.nn.parallel.DistributedDataParallel(
                 copy.deepcopy(DDP_NET).cuda(),
                 device_ids=[self.rank],
@@ -5690,7 +5692,7 @@ class DistributedTest:
             learning_rate = 0.03

             net_using_post_localSGD_opt = torch.nn.parallel.DistributedDataParallel(
-                Net().cuda(), device_ids=[self.rank]
+                copy.deepcopy(DDP_NET).cuda(), device_ids=[self.rank]
             )

             averager = create_averager()
@@ -5840,7 +5842,7 @@ class DistributedTest:
             bs_offset = int(rank * 2)
             global_bs = int(num_processes * 2)

-            model = nn.SyncBatchNorm(2, momentum=0.99)
+            model = ONLY_SBN_NET
             model_gpu = copy.deepcopy(model).cuda(rank)
             model_DDP = nn.parallel.DistributedDataParallel(
                 model_gpu, device_ids=[rank]
@@ -6050,7 +6052,6 @@ class DistributedTest:
         def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_Running_Value(
             self,
         ):
-            ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99)
             _group, _group_id, rank = self._init_global_test()
             model = nn.parallel.DistributedDataParallel(
                 ONLY_SBN_NET.cuda(rank), device_ids=[rank]
@@ -6118,7 +6119,7 @@ class DistributedTest:
         def test_DistributedDataParallel_SyncBatchNorm_half(self):
             _group, _group_id, rank = self._init_global_test()

-            model = BatchNormNet()
+            model = copy.deepcopy(BN_NET)
             model = model.half()
             model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
             model = nn.parallel.DistributedDataParallel(
@@ -6134,7 +6135,7 @@ class DistributedTest:

         def _test_ddp_logging_data(self, is_gpu):
             rank = dist.get_rank()
-            model_DDP = Net()
+            model_DDP = copy.deepcopy(DDP_NET)
             if is_gpu:
                 model_DDP = nn.parallel.DistributedDataParallel(
                     model_DDP.cuda(rank), device_ids=[rank]
@@ -6410,7 +6411,7 @@ class DistributedTest:
             BACKEND == "nccl", "nccl does not support DDP on CPU models"
         )
         def test_static_graph_api_cpu(self):
-            model_DDP = nn.parallel.DistributedDataParallel(Net())
+            model_DDP = nn.parallel.DistributedDataParallel(DDP_NET)
             expected_err = "should be called before training loop starts"
             with self.assertRaisesRegex(RuntimeError, expected_err):
                 local_bs = 2
@@ -6643,7 +6644,7 @@ class DistributedTest:
         def _test_allgather_object(self, subgroup=None):
             # Only set device for NCCL backend since it must use GPUs.

-            gather_objects = create_collectives_object_test_list()
+            gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()

             backend = os.environ["BACKEND"]
             if backend == "nccl":
@@ -6687,7 +6688,7 @@ class DistributedTest:

         def _test_gather_object(self, pg=None):
             # Ensure stateful objects can be gathered
-            gather_objects = create_collectives_object_test_list()
+            gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
             my_rank = dist.get_rank(pg)

             backend = os.environ["BACKEND"]
@@ -7257,7 +7258,7 @@ class DistributedTest:
                 return x

             torch.cuda.set_device(self.rank)
-            model_bn = BatchNormNet()
+            model_bn = BN_NET
             model_bn = nn.SyncBatchNorm.convert_sync_batchnorm(
                 copy.deepcopy(model_bn)
             ).cuda(self.rank)
@@ -7553,7 +7554,7 @@ class DistributedTest:
                 loss.backward()

         def _test_broadcast_object_list(self, group=None):
-            gather_objects = create_collectives_object_test_list()
+            gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()

             # Only set device for NCCL backend since it must use GPUs.
             # Case where rank != GPU device.
@@ -8277,11 +8278,10 @@ class DistributedTest:
         @require_backend_is_available({"gloo"})
         def test_scatter_object_list(self):
             src_rank = 0
-            collectives_object_test_list = create_collectives_object_test_list()
             scatter_list = (
-                collectives_object_test_list
+                COLLECTIVES_OBJECT_TEST_LIST
                 if self.rank == src_rank
-                else [None for _ in collectives_object_test_list]
+                else [None for _ in COLLECTIVES_OBJECT_TEST_LIST]
             )
             world_size = dist.get_world_size()
             scatter_list = scatter_list[:world_size]
@@ -8294,8 +8294,8 @@ class DistributedTest:
             dist.scatter_object_list(output_obj_list, scatter_list, src=src_rank)
             self.assertEqual(
                 output_obj_list[0],
-                collectives_object_test_list[
-                    self.rank % len(collectives_object_test_list)
+                COLLECTIVES_OBJECT_TEST_LIST[
+                    self.rank % len(COLLECTIVES_OBJECT_TEST_LIST)
                 ],
             )
             # Ensure errors are raised upon incorrect arguments.
@@ -9981,7 +9981,7 @@ class DistributedTest:
             "Only Nccl & Gloo backend support DistributedDataParallel",
         )
         def test_sync_bn_logged(self):
-            model = BatchNormNet()
+            model = BN_NET
             rank = self.rank
             # single gpu training setup
             model_gpu = model.cuda(rank)
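With the revert applied, importing torch.testing._internal.common_utils parses the command line at import time again (SEED, GRAPH_EXECUTOR, UNITTEST_ARGS and the other globals are set as a side effect of the import), so a test file needs no explicit parse step. A small sketch of the usual pattern; the test class and assertion are illustrative only, not part of this commit:

    import torch
    from torch.testing._internal.common_utils import TestCase, run_tests


    class ExampleTest(TestCase):  # hypothetical
        def test_add(self):
            self.assertEqual(torch.tensor(1) + 1, torch.tensor(2))


    if __name__ == "__main__":
        # run_tests() defaults argv to UNITTEST_ARGS, which was already populated
        # from sys.argv (e.g. --seed, --jit-executor) when common_utils was imported.
        run_tests()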