mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] replace unittest.main with run_tests (#50451)
Summary: fix https://github.com/pytorch/pytorch/issues/50448. This replaces all `test/*.py` files with run_tests(). This PR does not address test files in the subdirectories because they seems unrelated. Pull Request resolved: https://github.com/pytorch/pytorch/pull/50451 Reviewed By: janeyx99 Differential Revision: D25899924 Pulled By: walterddr fbshipit-source-id: f7c861f0096624b2791ad6ef6a16b1c4895cce71
This commit is contained in:
committed by
Facebook GitHub Bot
parent
a4383a69d4
commit
fc5db4265b
@ -1,6 +1,6 @@
|
||||
import torch.testing._internal.expecttest as expecttest
|
||||
from torch.testing._internal import expecttest
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
|
||||
import unittest
|
||||
import string
|
||||
import textwrap
|
||||
import doctest
|
||||
@ -17,7 +17,7 @@ def text_lineno(draw):
|
||||
return (t, lineno)
|
||||
|
||||
|
||||
class TestExpectTest(expecttest.TestCase):
|
||||
class TestExpectTest(TestCase):
|
||||
@hypothesis.given(text_lineno())
|
||||
def test_nth_line_ref(self, t_lineno):
|
||||
t, lineno = t_lineno
|
||||
@ -103,4 +103,4 @@ def load_tests(loader, tests, ignore):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
run_tests()
|
||||
|
@ -1,9 +1,8 @@
|
||||
import unittest
|
||||
import sys
|
||||
import os
|
||||
import contextlib
|
||||
import subprocess
|
||||
from torch.testing._internal.common_utils import TemporaryFileName
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests, TemporaryFileName
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@ -16,7 +15,7 @@ def _jit_disabled():
|
||||
os.environ["PYTORCH_JIT"] = cur_env
|
||||
|
||||
|
||||
class TestJitDisabled(unittest.TestCase):
|
||||
class TestJitDisabled(TestCase):
|
||||
"""
|
||||
These tests are separate from the rest of the JIT tests because we need
|
||||
run a new subprocess and `import torch` with the correct environment
|
||||
@ -91,4 +90,4 @@ print("Didn't throw exception")
|
||||
self.compare_enabled_disabled(_program_string)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
run_tests()
|
||||
|
@ -3,6 +3,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.backends.xnnpack
|
||||
import torch.utils.bundled_inputs
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
from torch.testing._internal.jit_utils import get_forward, get_forward_graph
|
||||
from torch.utils.mobile_optimizer import *
|
||||
from torch.nn import functional as F
|
||||
@ -12,7 +13,7 @@ from torch.nn.modules.module import ModuleAttributeError
|
||||
|
||||
FileCheck = torch._C.FileCheck
|
||||
|
||||
class TestOptimizer(unittest.TestCase):
|
||||
class TestOptimizer(TestCase):
|
||||
|
||||
@unittest.skipUnless(torch.backends.xnnpack.enabled,
|
||||
" XNNPACK must be enabled for these tests."
|
||||
@ -430,4 +431,4 @@ class TestOptimizer(unittest.TestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
run_tests()
|
||||
|
@ -1,9 +1,10 @@
|
||||
import os
|
||||
import re
|
||||
import yaml
|
||||
import unittest
|
||||
import textwrap
|
||||
import torch
|
||||
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
@ -17,7 +18,7 @@ all_operators_with_namedtuple_return = {
|
||||
}
|
||||
|
||||
|
||||
class TestNamedTupleAPI(unittest.TestCase):
|
||||
class TestNamedTupleAPI(TestCase):
|
||||
|
||||
def test_native_functions_yaml(self):
|
||||
operators_found = set()
|
||||
@ -108,4 +109,4 @@ class TestNamedTupleAPI(unittest.TestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
run_tests()
|
||||
|
@ -1,11 +1,10 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import unittest
|
||||
import inspect
|
||||
import functools
|
||||
import pprint
|
||||
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
from torch.overrides import (
|
||||
handle_torch_function,
|
||||
has_torch_function,
|
||||
@ -880,4 +879,4 @@ class TestWrapTorchFunction(TestCase):
|
||||
self.assertEqual(f(A()), -1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
run_tests()
|
||||
|
@ -1,5 +1,5 @@
|
||||
from unittest import main, skipIf
|
||||
from torch.testing._internal.common_utils import TestCase, IS_WINDOWS
|
||||
from unittest import skipIf
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
|
||||
from tempfile import NamedTemporaryFile
|
||||
from torch.package import PackageExporter, PackageImporter
|
||||
from pathlib import Path
|
||||
@ -392,4 +392,4 @@ def load():
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
run_tests()
|
||||
|
@ -4,9 +4,9 @@ import tempfile
|
||||
import torch
|
||||
import torch.utils.show_pickle
|
||||
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
|
||||
|
||||
class TestShowPickle(unittest.TestCase):
|
||||
class TestShowPickle(TestCase):
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "Can't re-open temp file on Windows")
|
||||
def test_scripted_model(self):
|
||||
@ -31,4 +31,4 @@ class TestShowPickle(unittest.TestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
run_tests()
|
||||
|
@ -4,7 +4,7 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
import unittest
|
||||
|
||||
from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs
|
||||
from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs, run_tests
|
||||
|
||||
from torch.testing._internal.te_utils import CudaCodeGenCreated, CudaCodeGenExecuted, \
|
||||
LLVMCodeGenExecuted, SimpleIREvalExecuted
|
||||
@ -1647,4 +1647,4 @@ class TestTensorExprFuser(BaseTestClass):
|
||||
self.assertEqual(ref, exp)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user