mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Revert "Stop parsing command line arguments every time common_utils is imported. (#156703)"
This reverts commit 310f901a71e53688866b14bb2f2b4c8eef9979b3. Reverted https://github.com/pytorch/pytorch/pull/156703 on behalf of https://github.com/izaitsevfb due to breaking tests internally with `assert common_utils.SEED is not None` ([comment](https://github.com/pytorch/pytorch/pull/156703#issuecomment-3152337518))
This commit is contained in:
@ -32,7 +32,6 @@ import torch.nn as nn
|
||||
from torch._C._autograd import DeviceType
|
||||
from torch._C._distributed_c10d import _SymmetricMemory
|
||||
from torch._logging._internal import trace_log
|
||||
from torch.testing._internal import common_utils
|
||||
from torch.testing._internal.common_utils import (
|
||||
FILE_SCHEMA,
|
||||
find_free_port,
|
||||
@ -672,7 +671,6 @@ class MultiProcessTestCase(TestCase):
|
||||
if methodName != "runTest":
|
||||
method_name = methodName
|
||||
super().__init__(method_name)
|
||||
self.seed = None
|
||||
try:
|
||||
fn = getattr(self, method_name)
|
||||
setattr(self, method_name, self.join_or_run(fn))
|
||||
@ -717,20 +715,13 @@ class MultiProcessTestCase(TestCase):
|
||||
|
||||
def _start_processes(self, proc) -> None:
|
||||
self.processes = []
|
||||
assert common_utils.SEED is not None
|
||||
for rank in range(int(self.world_size)):
|
||||
parent_conn, child_conn = torch.multiprocessing.Pipe()
|
||||
process = proc(
|
||||
target=self.__class__._run,
|
||||
name="process " + str(rank),
|
||||
args=(
|
||||
rank,
|
||||
self._current_test_name(),
|
||||
self.file_name,
|
||||
child_conn,
|
||||
),
|
||||
args=(rank, self._current_test_name(), self.file_name, child_conn),
|
||||
kwargs={
|
||||
"seed": common_utils.SEED,
|
||||
"fake_pg": getattr(self, "fake_pg", False),
|
||||
},
|
||||
)
|
||||
@ -784,12 +775,11 @@ class MultiProcessTestCase(TestCase):
|
||||
|
||||
@classmethod
|
||||
def _run(
|
||||
cls, rank: int, test_name: str, file_name: str, parent_pipe, seed: int, **kwargs
|
||||
cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs
|
||||
) -> None:
|
||||
self = cls(test_name)
|
||||
self.rank = rank
|
||||
self.file_name = file_name
|
||||
self.seed = seed
|
||||
self.run_test(test_name, parent_pipe)
|
||||
|
||||
def run_test(self, test_name: str, parent_pipe) -> None:
|
||||
@ -808,9 +798,6 @@ class MultiProcessTestCase(TestCase):
|
||||
# Show full C++ stacktraces when a Python error originating from C++ is raised.
|
||||
os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"
|
||||
|
||||
if self.seed is not None:
|
||||
common_utils.set_rng_seed(self.seed)
|
||||
|
||||
# self.id() == e.g. '__main__.TestDistributed.test_get_rank'
|
||||
# We're retrieving a corresponding test and executing it.
|
||||
try:
|
||||
@ -1548,7 +1535,7 @@ class DynamoDistributedMultiProcTestCase(DistributedTestBase):
|
||||
|
||||
@classmethod
|
||||
def _run(
|
||||
cls, rank: int, test_name: str, file_name: str, parent_pipe, seed: int, **kwargs
|
||||
cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs
|
||||
) -> None:
|
||||
trace_log.addHandler(logging.NullHandler())
|
||||
|
||||
@ -1556,7 +1543,6 @@ class DynamoDistributedMultiProcTestCase(DistributedTestBase):
|
||||
self = cls(test_name)
|
||||
self.rank = rank
|
||||
self.file_name = file_name
|
||||
self.seed = seed
|
||||
self.run_test(test_name, parent_pipe)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user