Ensure TORCH_TRACE is run for Dynamo/Distributed tests (#139786)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139786
Approved by: https://github.com/bobrenjc93, https://github.com/c00w, https://github.com/anijain2305
ghstack dependencies: #139716
This commit is contained in:
Edward Z. Yang
2024-11-05 20:00:52 -08:00
committed by PyTorch MergeBot
parent 47446cb5f3
commit 4e647871d6
5 changed files with 28 additions and 2 deletions

View File

@ -24,6 +24,7 @@ from io import StringIO
from typing import Dict, NamedTuple, Optional, Union, List, Any, Callable, Tuple
from unittest.mock import patch
from torch._logging._internal import trace_log
import torch
import torch._dynamo.test_case
import torch.cuda.nccl
@ -1348,6 +1349,8 @@ class DynamoDistributedMultiProcTestCase(MultiProcessTestCase):
@classmethod
def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs) -> None:
trace_log.addHandler(logging.NullHandler())
# The rest is copypasta from MultiProcessTestCase._run
self = cls(test_name)
self.rank = rank