Revert "Stop parsing command line arguments every time common_utils is imported. (#156703)"

This reverts commit 310f901a71e53688866b14bb2f2b4c8eef9979b3.

Reverted https://github.com/pytorch/pytorch/pull/156703 on behalf of https://github.com/izaitsevfb due to breaking tests internally with `assert common_utils.SEED is not None` ([comment](https://github.com/pytorch/pytorch/pull/156703#issuecomment-3152337518))
This commit is contained in:
PyTorch MergeBot
2025-08-04 20:37:39 +00:00
parent d4109a0f99
commit 356ac3103a
16 changed files with 109 additions and 255 deletions

View File

@ -32,7 +32,6 @@ import torch.nn as nn
from torch._C._autograd import DeviceType
from torch._C._distributed_c10d import _SymmetricMemory
from torch._logging._internal import trace_log
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import (
FILE_SCHEMA,
find_free_port,
@ -672,7 +671,6 @@ class MultiProcessTestCase(TestCase):
if methodName != "runTest":
method_name = methodName
super().__init__(method_name)
self.seed = None
try:
fn = getattr(self, method_name)
setattr(self, method_name, self.join_or_run(fn))
@ -717,20 +715,13 @@ class MultiProcessTestCase(TestCase):
def _start_processes(self, proc) -> None:
self.processes = []
assert common_utils.SEED is not None
for rank in range(int(self.world_size)):
parent_conn, child_conn = torch.multiprocessing.Pipe()
process = proc(
target=self.__class__._run,
name="process " + str(rank),
args=(
rank,
self._current_test_name(),
self.file_name,
child_conn,
),
args=(rank, self._current_test_name(), self.file_name, child_conn),
kwargs={
"seed": common_utils.SEED,
"fake_pg": getattr(self, "fake_pg", False),
},
)
@ -784,12 +775,11 @@ class MultiProcessTestCase(TestCase):
@classmethod
def _run(
cls, rank: int, test_name: str, file_name: str, parent_pipe, seed: int, **kwargs
cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs
) -> None:
self = cls(test_name)
self.rank = rank
self.file_name = file_name
self.seed = seed
self.run_test(test_name, parent_pipe)
def run_test(self, test_name: str, parent_pipe) -> None:
@ -808,9 +798,6 @@ class MultiProcessTestCase(TestCase):
# Show full C++ stacktraces when a Python error originating from C++ is raised.
os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"
if self.seed is not None:
common_utils.set_rng_seed(self.seed)
# self.id() == e.g. '__main__.TestDistributed.test_get_rank'
# We're retrieving a corresponding test and executing it.
try:
@ -1548,7 +1535,7 @@ class DynamoDistributedMultiProcTestCase(DistributedTestBase):
@classmethod
def _run(
cls, rank: int, test_name: str, file_name: str, parent_pipe, seed: int, **kwargs
cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs
) -> None:
trace_log.addHandler(logging.NullHandler())
@ -1556,7 +1543,6 @@ class DynamoDistributedMultiProcTestCase(DistributedTestBase):
self = cls(test_name)
self.rank = rank
self.file_name = file_name
self.seed = seed
self.run_test(test_name, parent_pipe)