[BE] replace unittest.main with run_tests (#50451)

Summary:
fix https://github.com/pytorch/pytorch/issues/50448.

This replaces all `test/*.py` files with run_tests(). This PR does not address test files in the subdirectories because they seems unrelated.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50451

Reviewed By: janeyx99

Differential Revision: D25899924

Pulled By: walterddr

fbshipit-source-id: f7c861f0096624b2791ad6ef6a16b1c4895cce71
This commit is contained in:
Rong Rong (AI Infra)
2021-01-13 10:30:17 -08:00
committed by Facebook GitHub Bot
parent a4383a69d4
commit fc5db4265b
8 changed files with 24 additions and 24 deletions

View File

@ -1,6 +1,6 @@
import torch.testing._internal.expecttest as expecttest
from torch.testing._internal import expecttest
from torch.testing._internal.common_utils import TestCase, run_tests
import unittest
import string
import textwrap
import doctest
@ -17,7 +17,7 @@ def text_lineno(draw):
return (t, lineno)
class TestExpectTest(expecttest.TestCase):
class TestExpectTest(TestCase):
@hypothesis.given(text_lineno())
def test_nth_line_ref(self, t_lineno):
t, lineno = t_lineno
@ -103,4 +103,4 @@ def load_tests(loader, tests, ignore):
if __name__ == '__main__':
unittest.main()
run_tests()

View File

@ -1,9 +1,8 @@
import unittest
import sys
import os
import contextlib
import subprocess
from torch.testing._internal.common_utils import TemporaryFileName
from torch.testing._internal.common_utils import TestCase, run_tests, TemporaryFileName
@contextlib.contextmanager
@ -16,7 +15,7 @@ def _jit_disabled():
os.environ["PYTORCH_JIT"] = cur_env
class TestJitDisabled(unittest.TestCase):
class TestJitDisabled(TestCase):
"""
These tests are separate from the rest of the JIT tests because we need
run a new subprocess and `import torch` with the correct environment
@ -91,4 +90,4 @@ print("Didn't throw exception")
self.compare_enabled_disabled(_program_string)
if __name__ == '__main__':
unittest.main()
run_tests()

View File

@ -3,6 +3,7 @@ import torch
import torch.nn as nn
import torch.backends.xnnpack
import torch.utils.bundled_inputs
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.jit_utils import get_forward, get_forward_graph
from torch.utils.mobile_optimizer import *
from torch.nn import functional as F
@ -12,7 +13,7 @@ from torch.nn.modules.module import ModuleAttributeError
FileCheck = torch._C.FileCheck
class TestOptimizer(unittest.TestCase):
class TestOptimizer(TestCase):
@unittest.skipUnless(torch.backends.xnnpack.enabled,
" XNNPACK must be enabled for these tests."
@ -430,4 +431,4 @@ class TestOptimizer(unittest.TestCase):
if __name__ == '__main__':
unittest.main()
run_tests()

View File

@ -1,9 +1,10 @@
import os
import re
import yaml
import unittest
import textwrap
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from collections import namedtuple
@ -17,7 +18,7 @@ all_operators_with_namedtuple_return = {
}
class TestNamedTupleAPI(unittest.TestCase):
class TestNamedTupleAPI(TestCase):
def test_native_functions_yaml(self):
operators_found = set()
@ -108,4 +109,4 @@ class TestNamedTupleAPI(unittest.TestCase):
if __name__ == '__main__':
unittest.main()
run_tests()

View File

@ -1,11 +1,10 @@
import torch
import numpy as np
import unittest
import inspect
import functools
import pprint
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.overrides import (
handle_torch_function,
has_torch_function,
@ -880,4 +879,4 @@ class TestWrapTorchFunction(TestCase):
self.assertEqual(f(A()), -1)
if __name__ == '__main__':
unittest.main()
run_tests()

View File

@ -1,5 +1,5 @@
from unittest import main, skipIf
from torch.testing._internal.common_utils import TestCase, IS_WINDOWS
from unittest import skipIf
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
from tempfile import NamedTemporaryFile
from torch.package import PackageExporter, PackageImporter
from pathlib import Path
@ -392,4 +392,4 @@ def load():
if __name__ == '__main__':
main()
run_tests()

View File

@ -4,9 +4,9 @@ import tempfile
import torch
import torch.utils.show_pickle
from torch.testing._internal.common_utils import IS_WINDOWS
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
class TestShowPickle(unittest.TestCase):
class TestShowPickle(TestCase):
@unittest.skipIf(IS_WINDOWS, "Can't re-open temp file on Windows")
def test_scripted_model(self):
@ -31,4 +31,4 @@ class TestShowPickle(unittest.TestCase):
if __name__ == '__main__':
unittest.main()
run_tests()

View File

@ -4,7 +4,7 @@ import torch.nn.functional as F
from torch import nn
import unittest
from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs
from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs, run_tests
from torch.testing._internal.te_utils import CudaCodeGenCreated, CudaCodeGenExecuted, \
LLVMCodeGenExecuted, SimpleIREvalExecuted
@ -1647,4 +1647,4 @@ class TestTensorExprFuser(BaseTestClass):
self.assertEqual(ref, exp)
if __name__ == '__main__':
unittest.main()
run_tests()