mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add __main__ guards to jit tests (#154725)
This PR is part of a series attempting to re-submit https://github.com/pytorch/pytorch/pull/134592 as smaller PRs. In jit tests: - Add and use a common raise_on_run_directly method for when a user runs a test file directly which should not be run this way. Print the file which the user should have run. - Raise a RuntimeError on tests which have been disabled (not run) Pull Request resolved: https://github.com/pytorch/pytorch/pull/154725 Approved by: https://github.com/clee2000
This commit is contained in:
committed by
PyTorch MergeBot
parent
f810e98143
commit
bf7e290854
@ -2,18 +2,13 @@
|
||||
|
||||
import torch
|
||||
from torch._C import parse_ir
|
||||
from torch.testing._internal.common_utils import TemporaryFileName
|
||||
from torch.testing._internal.common_utils import (
|
||||
raise_on_run_directly,
|
||||
TemporaryFileName,
|
||||
)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestAliasAnalysis(JitTestCase):
|
||||
def test_becomes_wildcard_annotations(self):
|
||||
graph_str = """
|
||||
@ -154,3 +149,7 @@ class TestAliasAnalysis(JitTestCase):
|
||||
mod = ModuleWrapper(module_list)
|
||||
mod = torch.jit.script(mod)
|
||||
mod(torch.zeros((2, 2)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -16,6 +16,7 @@ from typing import List
|
||||
|
||||
from torch import Tensor
|
||||
from torch.jit import Future
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import _inline_everything, JitTestCase
|
||||
|
||||
|
||||
@ -547,8 +548,4 @@ class TestAsync(JitTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
|
||||
|
||||
|
||||
class TestAtenPow(TestCase):
|
||||
@ -99,3 +99,7 @@ class TestAtenPow(TestCase):
|
||||
self.assertEqual(fn_float_float(0.0, -0.0), 0.0 ** (-0.0))
|
||||
# zero base and negative exponent case that should trigger RunTimeError
|
||||
self.assertRaises(RuntimeError, fn_float_float, 0.0, -2.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -4,17 +4,10 @@ from typing import NamedTuple, Tuple
|
||||
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestGetDefaultAttr(JitTestCase):
|
||||
def test_getattr_with_default(self):
|
||||
class A(torch.nn.Module):
|
||||
@ -66,3 +59,7 @@ class TestGetDefaultAttr(JitTestCase):
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "but got a normal Tuple"):
|
||||
torch.jit.script(fn)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -4,7 +4,10 @@
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import skipIfTorchDynamo
|
||||
from torch.testing._internal.common_utils import (
|
||||
raise_on_run_directly,
|
||||
skipIfTorchDynamo,
|
||||
)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
@ -145,3 +148,7 @@ class TestAutodiffJit(JitTestCase):
|
||||
self.assertEqual(x_s.requires_grad, x.requires_grad)
|
||||
self.assertEqual(y_s.requires_grad, y.requires_grad)
|
||||
self.assertEqual(z_s.requires_grad, z.requires_grad)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -20,20 +20,13 @@ sys.path.append(pytorch_test_dir)
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import (
|
||||
disable_autodiff_subgraph_inlining,
|
||||
JitTestCase,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients"
|
||||
)
|
||||
@ -589,3 +582,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
FileCheck().check("= prim::DifferentiableGraph").check(
|
||||
"with prim::DifferentiableGraph"
|
||||
).check(" = aten::relu").check("requires_grad=0").check("aten::relu").run(graph)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -6,6 +6,7 @@ from typing import List, Optional, Tuple
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch._awaits import _Await as Await
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
|
||||
@ -390,3 +391,7 @@ class TestAwait(JitTestCase):
|
||||
sm = torch.jit.load(iofile)
|
||||
script_out_load = sm(inp)
|
||||
self.assertTrue(torch.allclose(expected, script_out_load))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -7,7 +7,11 @@ from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch._C
|
||||
from torch.testing._internal.common_utils import IS_FBCODE, skipIfTorchDynamo
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_FBCODE,
|
||||
raise_on_run_directly,
|
||||
skipIfTorchDynamo,
|
||||
)
|
||||
|
||||
|
||||
# hacky way to skip these tests in fbcode:
|
||||
@ -28,13 +32,6 @@ else:
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
"""
|
||||
Unit Tests for Nnapi backend with delegate
|
||||
Inherits most tests from TestNNAPI, which loads Android NNAPI models
|
||||
@ -139,3 +136,7 @@ method_compile_spec must use the following format:
|
||||
def tearDown(self):
|
||||
# Change dtype back to default (Otherwise, other unit tests will complain)
|
||||
torch.set_default_dtype(self.default_dtype)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -15,6 +15,7 @@ from torch.testing._internal.common_utils import (
|
||||
IS_MACOS,
|
||||
IS_SANDCASTLE,
|
||||
IS_WINDOWS,
|
||||
raise_on_run_directly,
|
||||
skipIfRocm,
|
||||
TEST_WITH_ROCM,
|
||||
)
|
||||
@ -25,13 +26,6 @@ from torch.testing._internal.jit_utils import JitTestCase
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
def to_test_backend(module, method_compile_spec):
|
||||
return torch._C._jit_to_backend(
|
||||
@ -822,3 +816,7 @@ class AddedAttributesTest(JitBackendTestCase):
|
||||
)
|
||||
self.assertEqual(pre_bundled, post_bundled)
|
||||
self.assertEqual(post_bundled, post_load)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -2,24 +2,19 @@
|
||||
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestBatchMM(JitTestCase):
|
||||
@staticmethod
|
||||
def _get_test_tensors(n: int):
|
||||
return [
|
||||
torch.tensor([[1 + x, 2 + x, 3 + x], [4 + x, 5 + x, 6 + x]])
|
||||
if x % 2 == 0
|
||||
else torch.tensor([[1 + x, 2 + x], [3 + x, 4 + x], [5 + x, 6 + x]])
|
||||
(
|
||||
torch.tensor([[1 + x, 2 + x, 3 + x], [4 + x, 5 + x, 6 + x]])
|
||||
if x % 2 == 0
|
||||
else torch.tensor([[1 + x, 2 + x], [3 + x, 4 + x], [5 + x, 6 + x]])
|
||||
)
|
||||
for x in range(n)
|
||||
]
|
||||
|
||||
@ -288,3 +283,7 @@ class TestBatchMM(JitTestCase):
|
||||
FileCheck().check_count("aten::mm", 10, exactly=True).check_not(
|
||||
"prim::MMBatchSide"
|
||||
).run(test_batch_mm.graph)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -13,17 +13,10 @@ from torch.testing import FileCheck
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestBuiltins(JitTestCase):
|
||||
"""
|
||||
Tests for TorchScript support of Python builtin functions.
|
||||
@ -299,3 +292,7 @@ class TestTensorBuiltins(JitTestCase):
|
||||
self.assertEqual(
|
||||
test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -18,18 +18,14 @@ sys.path.append(pytorch_test_dir)
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import torch.testing._internal.jit_utils
|
||||
from torch.testing._internal.common_utils import IS_SANDCASTLE, skipIfTorchDynamo
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_SANDCASTLE,
|
||||
raise_on_run_directly,
|
||||
skipIfTorchDynamo,
|
||||
)
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestClassType(JitTestCase):
|
||||
def test_reference_semantics(self):
|
||||
"""
|
||||
@ -1667,3 +1663,7 @@ class TestClassType(JitTestCase):
|
||||
for fn in (fn_a, fn_b, fn_c, fn_d, fn_e):
|
||||
with self.assertRaisesRegex(RuntimeError, error_message_regex):
|
||||
torch.jit.script(fn)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -8,7 +8,7 @@ from textwrap import dedent
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import IS_MACOS
|
||||
from torch.testing._internal.common_utils import IS_MACOS, raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import execWrapper, JitTestCase
|
||||
|
||||
|
||||
@ -617,3 +617,7 @@ class TestComplex(JitTestCase):
|
||||
scripted = torch.jit.script(op)
|
||||
jit_result = scripted(x, y)
|
||||
self.assertEqual(eager_result, jit_result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -13,7 +13,6 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_FBCODE,
|
||||
run_tests,
|
||||
set_default_dtype,
|
||||
suppress_warnings,
|
||||
)
|
||||
@ -105,4 +104,7 @@ class TestComplexity(JitTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
raise RuntimeError(
|
||||
"This test is not currently used and should be "
|
||||
"enabled in discover_tests.py if required."
|
||||
)
|
||||
|
@ -22,16 +22,10 @@ skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
activations = [
|
||||
F.celu,
|
||||
F.elu,
|
||||
@ -204,3 +198,7 @@ class TestInplaceToFunctionalActivation(JitTestCase):
|
||||
inp = torch.randn(N, C, H, W)
|
||||
self.run_pass("inplace_to_functional_activation", frozen_model.graph)
|
||||
self.assertEqual(model(inp), frozen_model(inp))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -12,6 +12,7 @@ from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_cuda import TEST_MULTIGPU
|
||||
from torch.testing._internal.common_utils import (
|
||||
NoTest,
|
||||
raise_on_run_directly,
|
||||
skipCUDANonDefaultStreamIf,
|
||||
skipIfRocm,
|
||||
TEST_CUDA,
|
||||
@ -36,13 +37,6 @@ if TEST_CUDA:
|
||||
torch.ones(1).cuda() # initialize cuda context
|
||||
TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 5e9
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestCUDA(JitTestCase):
|
||||
"""
|
||||
@ -698,3 +692,7 @@ class TestCUDA(JitTestCase):
|
||||
FileCheck().check("cuda::_maybe_exchange_device(").run(g)
|
||||
torch._C._jit_pass_inline(g)
|
||||
FileCheck().check("cuda::_maybe_exchange_device(").run(g)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -10,17 +10,10 @@ import torch
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
def canonical(graph):
|
||||
return torch._C._jit_pass_canonicalize(graph).str(False)
|
||||
|
||||
@ -151,3 +144,7 @@ graph(%x.1 : Tensor):
|
||||
def test_where_no_scalar(self):
|
||||
x = torch.rand(1, 3, 224, 224)
|
||||
torch.ops.aten.where(x > 0.5, -1.5, 1.5) # does not raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -12,17 +12,10 @@ import torch.nn.parallel as dp
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA_MULTI_GPU
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestDataParallel(JitTestCase):
|
||||
class Mpy(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -158,3 +151,7 @@ class TestDataParallel(JitTestCase):
|
||||
x1 = torch.ones(2, 2, requires_grad=True).cuda(device=1)
|
||||
r1_forward = replica[1](x1)
|
||||
self.assertEqual(first_forward, r1_forward)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -7,6 +7,7 @@ from typing import List, Optional
|
||||
from hypothesis import given, settings, strategies as st
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
@ -168,3 +169,7 @@ class TestDataclasses(JitTestCase):
|
||||
|
||||
with self.assertRaises(OSError):
|
||||
torch.jit.script(f)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
|
||||
@ -73,3 +74,7 @@ class TestDCE(JitTestCase):
|
||||
torch._C._jit_pass_dce_graph(fn_s.graph)
|
||||
|
||||
FileCheck().check("aten::add_").run(fn_s.graph)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -19,3 +19,10 @@ class TestDecorator(JitTestCase):
|
||||
fn = my_function_a
|
||||
fx = torch.jit.script(fn)
|
||||
self.assertEqual(fn(1.0), fx(1.0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test is not currently used and should be "
|
||||
"enabled in discover_tests.py if required."
|
||||
)
|
||||
|
@ -5,7 +5,7 @@ from itertools import product
|
||||
|
||||
import torch
|
||||
from torch.jit._passes._property_propagation import apply_input_props_using_example
|
||||
from torch.testing._internal.common_utils import TEST_CUDA
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly, TEST_CUDA
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
@ -14,13 +14,6 @@ try:
|
||||
except ImportError:
|
||||
models = None
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestDeviceAnalysis(JitTestCase):
|
||||
@classmethod
|
||||
@ -336,3 +329,7 @@ class TestDeviceAnalysis(JitTestCase):
|
||||
test_fn, [self.mkldnn, self.mkldnn, None, None], self.mkldnn
|
||||
)
|
||||
self.assert_device_equal(test_fn, [self.cpu, self.cuda, None, None], None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -17,7 +17,11 @@ from torch.testing._internal.common_methods_invocations import (
|
||||
sample_inputs_conv2d,
|
||||
SampleInput,
|
||||
)
|
||||
from torch.testing._internal.common_utils import first_sample, set_default_dtype
|
||||
from torch.testing._internal.common_utils import (
|
||||
first_sample,
|
||||
raise_on_run_directly,
|
||||
set_default_dtype,
|
||||
)
|
||||
from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
@ -27,14 +31,6 @@ Dtype Analysis relies on symbolic shape analysis, which is still in beta
|
||||
"""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
custom_rules_works_list = {
|
||||
"nn.functional.adaptive_avg_pool1d",
|
||||
"nn.functional.adaptive_avg_pool2d",
|
||||
@ -386,3 +382,6 @@ class TestDtypeCustomRules(TestDtypeBase):
|
||||
TestDtypeCustomRulesCPU = None
|
||||
# This creates TestDtypeCustomRulesCPU
|
||||
instantiate_device_type_tests(TestDtypeCustomRules, globals(), only_for=("cpu",))
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -12,17 +12,10 @@ from torch.testing import FileCheck
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestEnum(JitTestCase):
|
||||
def test_enum_value_types(self):
|
||||
class IntEnum(Enum):
|
||||
@ -358,3 +351,7 @@ class TestEnum(JitTestCase):
|
||||
@torch.jit.script
|
||||
def is_red(x: Color) -> bool:
|
||||
return x == Color.RED
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -197,3 +197,10 @@ class TestException(TestCase):
|
||||
"jit.myexception.MyKeyError: This is a user defined key error",
|
||||
):
|
||||
fn()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test is not currently used and should be "
|
||||
"enabled in discover_tests.py if required."
|
||||
)
|
||||
|
@ -15,6 +15,7 @@ from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN
|
||||
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
|
||||
from torch.testing._internal.common_quantized import override_quantized_engine
|
||||
from torch.testing._internal.common_utils import (
|
||||
raise_on_run_directly,
|
||||
set_default_dtype,
|
||||
skipCUDAMemoryLeakCheckIf,
|
||||
skipIfTorchDynamo,
|
||||
@ -32,13 +33,6 @@ except ImportError:
|
||||
HAS_TORCHVISION = False
|
||||
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
TEST_ROCM = torch.cuda.is_available() and torch.version.hip is not None
|
||||
|
||||
|
||||
@ -55,7 +49,7 @@ class TestFreezing(JitTestCase):
|
||||
self.a = 1 # folded
|
||||
self.b = 1.2 # folded
|
||||
self.c = "hello" # folded
|
||||
self.c2 = "hi\xA1" # not folded
|
||||
self.c2 = "hi\xa1" # not folded
|
||||
self.d = [1, 1] # folded
|
||||
self.e = [1.0, 1.1] # folded
|
||||
self.f = ["hello", "world"] # folded
|
||||
@ -67,7 +61,7 @@ class TestFreezing(JitTestCase):
|
||||
torch.tensor([5.5], requires_grad=True),
|
||||
) # folded
|
||||
self.h = {"layer": [torch.tensor([7.7], requires_grad=True)]}
|
||||
self.h2 = {"layer\xB1": [torch.tensor([8.8], requires_grad=True)]}
|
||||
self.h2 = {"layer\xb1": [torch.tensor([8.8], requires_grad=True)]}
|
||||
self.t = torch.tensor([1.2, 2.4], requires_grad=True) # folded
|
||||
self.ts = [
|
||||
torch.tensor([1.0, 2.0], requires_grad=True),
|
||||
@ -3461,3 +3455,7 @@ class TestMKLDNNReinplacing(JitTestCase):
|
||||
mod = self.freezeAndConvert(mod_eager)
|
||||
FileCheck().check("aten::add_").run(mod.graph)
|
||||
self.checkResults(mod_eager, mod)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -10,17 +10,10 @@ from torch.testing import FileCheck
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestFunctionalBlocks(JitTestCase):
|
||||
def test_subgraph_creation(self):
|
||||
def fn(x, y, z):
|
||||
@ -54,3 +47,7 @@ class TestFunctionalBlocks(JitTestCase):
|
||||
FileCheck().check("add").check("add_").check_not("mul").check(
|
||||
"FunctionalGraph"
|
||||
).run(graph)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
@ -19,3 +20,7 @@ class TestFuserCommon(JitTestCase):
|
||||
# test fallback when optimization is not applicable
|
||||
y = fn(torch.randn(5, requires_grad=rq))
|
||||
self.assertEqual(y.requires_grad, rq)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit_fuser_te.py")
|
||||
|
@ -6,18 +6,13 @@ import unittest
|
||||
|
||||
import torch
|
||||
from torch.nn import init
|
||||
from torch.testing._internal.common_utils import skipIfLegacyJitExecutor
|
||||
from torch.testing._internal.common_utils import (
|
||||
raise_on_run_directly,
|
||||
skipIfLegacyJitExecutor,
|
||||
)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestGenerator(JitTestCase):
|
||||
# torch.jit.trace does not properly capture the generator manual seed
|
||||
# and thus is non deterministic even if the generator is manually seeded
|
||||
@ -193,3 +188,7 @@ class TestGenerator(JitTestCase):
|
||||
except: # noqa: B001, E722
|
||||
print(loaded_module.forward.code)
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -3,6 +3,7 @@
|
||||
import torch
|
||||
import torch._C
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
@ -59,3 +60,7 @@ class TestGraphRewritePasses(JitTestCase):
|
||||
FileCheck().check_not("aten::linear").run(model.graph)
|
||||
# make sure it runs
|
||||
model(x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -10,17 +10,10 @@ import torch
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestHash(JitTestCase):
|
||||
def test_hash_tuple(self):
|
||||
def fn(t1: Tuple[int, int], t2: Tuple[int, int]) -> bool:
|
||||
@ -115,3 +108,7 @@ class TestHash(JitTestCase):
|
||||
self.checkScript(fn, (gpu0, gpu1))
|
||||
self.checkScript(fn, (gpu0, cpu))
|
||||
self.checkScript(fn, (cpu, cpu))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -33,17 +33,10 @@ from jit.test_hooks_modules import (
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
# Tests for JIT forward hooks and pre-hooks
|
||||
class TestHooks(JitTestCase):
|
||||
def test_module_no_forward_input(self):
|
||||
@ -393,3 +386,7 @@ class TestHooks(JitTestCase):
|
||||
r"Received type: 'str'. Expected type: 'Tuple\[str\]'",
|
||||
):
|
||||
torch.jit.script(m)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -528,3 +528,9 @@ def create_submodule_forward_single_input_return_not_tupled():
|
||||
m.submodule.register_forward_hook(forward_hook)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This file is a collection of utils, it should be imported not executed directly"
|
||||
)
|
||||
|
@ -11,17 +11,10 @@ from torch.testing import FileCheck
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
# Tests that Python slice class is supported in TorchScript
|
||||
class TestIgnorableArgs(JitTestCase):
|
||||
def test_slice_ignorable_args_for_slice(self):
|
||||
@ -61,3 +54,7 @@ class TestIgnorableArgs(JitTestCase):
|
||||
torch.add(x, y, out=y)
|
||||
|
||||
FileCheck().check("torch.add(x, y, out=y)").run(fn.code)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -11,17 +11,10 @@ import torch
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.jit.frontend import _IS_ASTUNPARSE_INSTALLED
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestIgnoreContextManager(JitTestCase):
|
||||
@unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
|
||||
def test_with_ignore_context_manager_with_inp_out(self):
|
||||
@ -103,3 +96,7 @@ class TestIgnoreContextManager(JitTestCase):
|
||||
s = torch.jit.script(model)
|
||||
self.assertEqual(s(), 5)
|
||||
self.assertEqual(s(), model())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -11,17 +11,10 @@ import torch
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
# Tests for torch.jit.isinstance
|
||||
class TestIsinstance(JitTestCase):
|
||||
def test_int(self):
|
||||
@ -354,3 +347,7 @@ class TestIsinstance(JitTestCase):
|
||||
# Should not throw "Boolean value of Tensor with more than
|
||||
# one value is ambiguous" error
|
||||
torch._jit_internal.check_empty_containers(torch.rand(2, 3))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -11,17 +11,10 @@ from torch.testing._internal import jit_utils
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
# Tests various JIT-related utility functions.
|
||||
class TestJitUtils(JitTestCase):
|
||||
# Tests that POSITIONAL_OR_KEYWORD arguments are captured.
|
||||
@ -116,3 +109,7 @@ class TestJitUtils(JitTestCase):
|
||||
with jit_utils.NoTracerWarnContextManager():
|
||||
self.assertEqual(False, torch._C._jit_get_tracer_state_warn())
|
||||
self.assertEqual(True, torch._C._jit_get_tracer_state_warn())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -19,18 +19,14 @@ from torch.testing import FileCheck
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import skipIfTorchDynamo, TEST_CUDA
|
||||
from torch.testing._internal.common_utils import (
|
||||
raise_on_run_directly,
|
||||
skipIfTorchDynamo,
|
||||
TEST_CUDA,
|
||||
)
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestList(JitTestCase):
|
||||
def test_list_bool_conversion(self):
|
||||
def if_predicate(l: List[int]):
|
||||
@ -1825,7 +1821,7 @@ class TestDict(JitTestCase):
|
||||
def test_popitem(self):
|
||||
@torch.jit.script
|
||||
def popitem(
|
||||
x: Dict[str, Tensor]
|
||||
x: Dict[str, Tensor],
|
||||
) -> Tuple[Tuple[str, Tensor], Dict[str, Tensor]]:
|
||||
item = x.popitem()
|
||||
return item, x
|
||||
@ -2996,3 +2992,7 @@ class TestScriptList(JitTestCase):
|
||||
for i in range(300):
|
||||
test = Test()
|
||||
test_script = torch.jit.script(test)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -10,17 +10,10 @@ import torch
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestLogging(JitTestCase):
|
||||
def test_bump_numeric_counter(self):
|
||||
class ModuleThatLogs(torch.jit.ScriptModule):
|
||||
@ -122,3 +115,7 @@ class TestLogging(JitTestCase):
|
||||
def test_logging_levels_set(self):
|
||||
torch._C._jit_set_logging_option("foo")
|
||||
self.assertEqual("foo", torch._C._jit_get_logging_option())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -12,7 +12,7 @@ import torch.testing._internal.jit_utils
|
||||
from jit.test_module_interface import TestModuleInterface # noqa: F401
|
||||
from torch import jit
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import freeze_rng_state
|
||||
from torch.testing._internal.common_utils import freeze_rng_state, raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global, RUN_CUDA_HALF
|
||||
|
||||
|
||||
@ -20,13 +20,6 @@ from torch.testing._internal.jit_utils import JitTestCase, make_global, RUN_CUDA
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestMisc(JitTestCase):
|
||||
def test_joined_str(self):
|
||||
@ -129,7 +122,7 @@ class TestMisc(JitTestCase):
|
||||
def test_subexpression_Tuple_int_int_Future(self):
|
||||
@torch.jit.script
|
||||
def fn(
|
||||
x: Tuple[int, int, torch.jit.Future[int]]
|
||||
x: Tuple[int, int, torch.jit.Future[int]],
|
||||
) -> Tuple[int, torch.jit.Future[int]]:
|
||||
return x[0], x[2]
|
||||
|
||||
@ -147,7 +140,7 @@ class TestMisc(JitTestCase):
|
||||
def test_subexpression_Optional(self):
|
||||
@torch.jit.script
|
||||
def fn(
|
||||
x: Optional[Dict[int, torch.jit.Future[int]]]
|
||||
x: Optional[Dict[int, torch.jit.Future[int]]],
|
||||
) -> Optional[torch.jit.Future[int]]:
|
||||
if x is not None:
|
||||
return x[0]
|
||||
@ -504,3 +497,7 @@ class TestMisc(JitTestCase):
|
||||
self.assertTrue(len(complex_indices) > 0)
|
||||
self.assertTrue(len(Scalar_indices) > 0)
|
||||
self.assertTrue(complex_indices[0] > Scalar_indices[0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -11,24 +11,19 @@ from torch.testing._internal.common_utils import (
|
||||
enable_profiling_mode_for_profiling_tests,
|
||||
GRAPH_EXECUTOR,
|
||||
ProfilingMode,
|
||||
raise_on_run_directly,
|
||||
set_default_dtype,
|
||||
slowTest,
|
||||
suppress_warnings,
|
||||
)
|
||||
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import slowTest, suppress_warnings
|
||||
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
try:
|
||||
import torchvision
|
||||
|
||||
@ -84,7 +79,7 @@ class TestModels(JitTestCase):
|
||||
nn.ReLU(True),
|
||||
# state size. (ngf) x 32 x 32
|
||||
nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
|
||||
nn.Tanh()
|
||||
nn.Tanh(),
|
||||
# state size. (nc) x 64 x 64
|
||||
)
|
||||
|
||||
@ -754,3 +749,7 @@ class TestModels(JitTestCase):
|
||||
m = self.createFunctionFromGraph(g)
|
||||
with torch.random.fork_rng(devices=[]):
|
||||
self.assertEqual(outputs, m(*inputs))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -5,6 +5,7 @@ import sys
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
@ -12,13 +13,6 @@ from torch.testing._internal.jit_utils import JitTestCase
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestModuleAPIs(JitTestCase):
|
||||
def test_default_state_dict_methods(self):
|
||||
@ -141,3 +135,7 @@ class TestModuleAPIs(JitTestCase):
|
||||
self.assertFalse(m2.sub.customized_load_state_dict_called)
|
||||
m2.load_state_dict(state_dict)
|
||||
self.assertTrue(m2.sub.customized_load_state_dict_called)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -7,6 +7,7 @@ from typing import Any, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
@ -14,13 +15,6 @@ from torch.testing._internal.jit_utils import JitTestCase
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestModuleContainers(JitTestCase):
|
||||
def test_sequential_intermediary_types(self):
|
||||
@ -756,3 +750,7 @@ class TestModuleContainers(JitTestCase):
|
||||
)
|
||||
|
||||
self.checkModule(MyModule(), (torch.ones(1),))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -8,6 +8,7 @@ from typing import Any, List
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
|
||||
@ -15,13 +16,6 @@ from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class OrigModule(nn.Module):
|
||||
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
|
||||
@ -701,3 +695,7 @@ class TestModuleInterface(JitTestCase):
|
||||
|
||||
with self.assertRaisesRegex(Exception, "Could not compile"):
|
||||
scripted_mod = torch.jit.script(TestModule())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -4,6 +4,7 @@ import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
@ -11,13 +12,6 @@ from torch.testing._internal.jit_utils import JitTestCase
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestModules(JitTestCase):
|
||||
def test_script_module_with_constants_list(self):
|
||||
@ -36,3 +30,7 @@ class TestModules(JitTestCase):
|
||||
self.x = 0
|
||||
|
||||
self.checkModule(Net(), (torch.randn(5),))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -2,17 +2,10 @@
|
||||
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestOpDecompositions(JitTestCase):
|
||||
def test_op_decomposition(self):
|
||||
def foo(x):
|
||||
@ -42,3 +35,7 @@ class TestOpDecompositions(JitTestCase):
|
||||
FileCheck().check_not("aten::square").check("aten::pow").run(foo.graph)
|
||||
x = torch.rand([4])
|
||||
self.assertEqual(foo(x), torch.square(x))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -3,7 +3,7 @@
|
||||
import torch
|
||||
import torch._C
|
||||
import torch.nn.functional as F
|
||||
from torch.testing._internal.common_utils import skipIfNoXNNPACK
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly, skipIfNoXNNPACK
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
@ -263,3 +263,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
|
||||
conv2d_activation=F.relu,
|
||||
conv2d_activation_kind="aten::relu",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -4,17 +4,10 @@
|
||||
import torch
|
||||
import torch.nn.utils.parametrize as parametrize
|
||||
from torch import nn
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestParametrization(JitTestCase):
|
||||
# Define some parametrization
|
||||
class Symmetric(nn.Module):
|
||||
@ -68,3 +61,7 @@ class TestParametrization(JitTestCase):
|
||||
# Check the scripting process throws an error when caching
|
||||
with self.assertRaisesRegex(RuntimeError, "Caching is not implemented"):
|
||||
scripted_model = torch.jit.trace_module(model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -6,7 +6,7 @@ from typing import Any, Dict, List, NamedTuple, Optional, Tuple # noqa: F401
|
||||
|
||||
import torch
|
||||
from torch.jit._monkeytype_config import _IS_MONKEYTYPE_INSTALLED
|
||||
from torch.testing._internal.common_utils import NoTest
|
||||
from torch.testing._internal.common_utils import NoTest, raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
|
||||
@ -21,13 +21,6 @@ if not _IS_MONKEYTYPE_INSTALLED:
|
||||
)
|
||||
JitTestCase = NoTest # type: ignore[misc, assignment] # noqa: F811
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestPDT(JitTestCase):
|
||||
"""
|
||||
@ -896,3 +889,7 @@ class TestPDT(JitTestCase):
|
||||
torch.ones(1),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -6,17 +6,10 @@ from typing import Callable, List
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import _inline_everything, JitTestCase, RUN_CUDA
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestPeephole(JitTestCase):
|
||||
def test_peephole_with_writes(self):
|
||||
def test_write(x):
|
||||
@ -890,3 +883,7 @@ class TestPeephole(JitTestCase):
|
||||
|
||||
self.run_pass("peephole", foo.graph)
|
||||
FileCheck().check("aten::slice").run(foo.graph)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -4,7 +4,10 @@ import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import skipIfTorchDynamo
|
||||
from torch.testing._internal.common_utils import (
|
||||
raise_on_run_directly,
|
||||
skipIfTorchDynamo,
|
||||
)
|
||||
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
@ -13,14 +16,6 @@ sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import FileCheck, JitTestCase, warmup_backward
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
@skipIfTorchDynamo()
|
||||
class TestProfiler(JitTestCase):
|
||||
def setUp(self):
|
||||
@ -284,3 +279,7 @@ class TestProfiler(JitTestCase):
|
||||
|
||||
g = torch.jit.last_executed_optimized_graph()
|
||||
self.assertEqual(len(list(g.findAllNodes("prim::TensorExprGroup"))), 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -2,17 +2,10 @@
|
||||
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TestPythonBindings\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestPythonBindings(JitTestCase):
|
||||
def test_cu_get_functions(self):
|
||||
@torch.jit.script
|
||||
@ -114,3 +107,7 @@ graph(%p207 : Tensor,
|
||||
graph3 = torch._C.parse_ir(ir)
|
||||
graph3 = torch._C._jit_pass_canonicalize(graph3, False)
|
||||
FileCheck().check_not("%p207").run(graph3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -7,6 +7,7 @@ import tempfile
|
||||
from textwrap import dedent
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import execWrapper, JitTestCase
|
||||
|
||||
|
||||
@ -14,13 +15,6 @@ from torch.testing._internal.jit_utils import execWrapper, JitTestCase
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
def get_fn(file_name, script_path):
|
||||
import importlib.util
|
||||
@ -473,3 +467,7 @@ class TestPythonBuiltinOP(JitTestCase):
|
||||
|
||||
s = torch.rand(1)
|
||||
self.assertTrue(foo(s))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -6,18 +6,10 @@ import numpy as np
|
||||
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import IS_MACOS
|
||||
from torch.testing._internal.common_utils import IS_MACOS, raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestPythonIr(JitTestCase):
|
||||
def test_param_strides(self):
|
||||
def trace_me(arg):
|
||||
@ -100,3 +92,7 @@ class TestPythonIr(JitTestCase):
|
||||
|
||||
FileCheck().check_not("aten::mul").check("aten::add").run(foo.graph)
|
||||
self.assertEqual(foo(torch.ones([2, 2])), torch.ones([2, 2]) * 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -20,20 +20,13 @@ from torch.testing import FileCheck
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import (
|
||||
_tmp_donotuse_dont_inline_everything,
|
||||
JitTestCase,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestRecursiveScript(JitTestCase):
|
||||
def test_inferred_nonetype(self):
|
||||
class M(nn.Module):
|
||||
@ -799,3 +792,7 @@ class TestRecursiveScript(JitTestCase):
|
||||
# ScriptModule should correctly reflect the override.
|
||||
s = torch.jit.script(m)
|
||||
self.assertEqual(s.i_am_ignored(), "new")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -11,17 +11,10 @@ from torch.testing import FileCheck
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import freeze_rng_state, JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestRemoveMutation(JitTestCase):
|
||||
def test_aten_inplace(self):
|
||||
def test_not_new_alias(x):
|
||||
@ -318,3 +311,7 @@ class TestRemoveMutation(JitTestCase):
|
||||
|
||||
self.run_pass("remove_mutation", mod_script.forward.graph)
|
||||
FileCheck().check("aten::add_").run(test_multiple_uses.graph)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -8,7 +8,11 @@ from typing import NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.testing._internal.common_utils import skipIfTorchDynamo, TemporaryFileName
|
||||
from torch.testing._internal.common_utils import (
|
||||
raise_on_run_directly,
|
||||
skipIfTorchDynamo,
|
||||
TemporaryFileName,
|
||||
)
|
||||
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
@ -17,14 +21,6 @@ sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import clear_class_registry, JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestSaveLoad(JitTestCase):
|
||||
def test_different_modules(self):
|
||||
"""
|
||||
@ -1197,3 +1193,7 @@ class TestSaveLoadFlatbuffer(JitTestCase):
|
||||
torch._C._get_model_extra_files_from_buffer(script_module_io, re_extra_files)
|
||||
|
||||
self.assertEqual(extra_files, re_extra_files)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -17,17 +17,10 @@ import torch
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.jit.mobile import _load_for_lite_interpreter
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestSaveLoadForOpVersion(JitTestCase):
|
||||
# Helper that returns the module after saving and loading
|
||||
def _save_load_module(self, m):
|
||||
@ -617,3 +610,7 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
self.assertTrue(output.size(dim=0) == 100)
|
||||
# "Upgraded" model should match the new version output
|
||||
self.assertEqual(output, output_current)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -10,17 +10,10 @@ from torch import nn
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class Sequence(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
@ -115,3 +108,7 @@ class TestScriptProfile(JitTestCase):
|
||||
p.enable()
|
||||
p.disable()
|
||||
self.assertEqual(p.dump_string(), "")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -11,17 +11,10 @@ import torch
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||
# NB: There are no tests for `Tuple` or `NamedTuple` here. In fact,
|
||||
# reassigning a non-empty Tuple to an attribute previously typed
|
||||
@ -363,3 +356,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||
"empty non-base types",
|
||||
):
|
||||
torch.jit.script(M())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -10,17 +10,10 @@ import torch
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
# Tests that Python slice class is supported in TorchScript
|
||||
class TestSlice(JitTestCase):
|
||||
def test_slice_kwarg(self):
|
||||
@ -178,3 +171,7 @@ class TestSlice(JitTestCase):
|
||||
self.assertEqual(result2[0].identifier, "B")
|
||||
self.assertEqual(result2[1].identifier, "C")
|
||||
self.assertEqual(result2[2].identifier, "D")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -4,7 +4,11 @@ import io
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS, TEST_MKL
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_WINDOWS,
|
||||
raise_on_run_directly,
|
||||
TEST_MKL,
|
||||
)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
@ -118,3 +122,7 @@ class TestSparse(JitTestCase):
|
||||
loaded_result = loaded_model.forward(x)
|
||||
|
||||
self.assertEqual(expected_result.to_dense(), loaded_result.to_dense())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -10,17 +10,10 @@ import torch
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestStringFormatting(JitTestCase):
|
||||
def test_modulo_operator(self):
|
||||
def fn(dividend: int, divisor: int) -> int:
|
||||
@ -199,3 +192,7 @@ class TestStringFormatting(JitTestCase):
|
||||
'"%a in template" % arg1',
|
||||
):
|
||||
fn("foo")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -9,18 +9,10 @@ import torch
|
||||
from torch import nn, Tensor
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_methods_invocations import sample_inputs_cat_concat
|
||||
from torch.testing._internal.common_utils import make_tensor
|
||||
from torch.testing._internal.common_utils import make_tensor, raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import execWrapper, JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
# XXX: still in prototype
|
||||
class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
def setUp(self):
|
||||
@ -819,3 +811,7 @@ class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
input.setType(input.type().with_sizes([1, 5, 8]))
|
||||
torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)
|
||||
self.assertEqual(next(foo.graph.outputs()).type().symbolic_sizes(), [5, 8])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -9,17 +9,10 @@ import torch
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestTensorCreationOps(JitTestCase):
|
||||
"""
|
||||
A suite of tests for ops that create tensors.
|
||||
@ -78,3 +71,7 @@ class TestTensorCreationOps(JitTestCase):
|
||||
assert indices.dtype == torch.int32
|
||||
|
||||
self.checkScript(tril_indices, (3, 3))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -10,17 +10,10 @@ import torch
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestTensorMethods(JitTestCase):
|
||||
def test_getitem(self):
|
||||
def tensor_getitem(inp: torch.Tensor):
|
||||
@ -41,3 +34,7 @@ class TestTensorMethods(JitTestCase):
|
||||
RuntimeError, "expected exactly 1 argument", "inp.__getitem__"
|
||||
):
|
||||
torch.jit.script(tensor_getitem_invalid)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -8,7 +8,10 @@ import sys
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import skipIfTorchDynamo
|
||||
from torch.testing._internal.common_utils import (
|
||||
raise_on_run_directly,
|
||||
skipIfTorchDynamo,
|
||||
)
|
||||
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
@ -19,14 +22,6 @@ from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.testing._internal.torchbind_impls import load_torchbind_test_lib
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
@skipIfTorchDynamo("skipping as a precaution")
|
||||
class TestTorchbind(JitTestCase):
|
||||
def setUp(self):
|
||||
@ -463,3 +458,7 @@ class TestTorchbind(JitTestCase):
|
||||
return obj.decrement()
|
||||
|
||||
self.checkScript(gn, ())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -29,6 +29,7 @@ from torch.testing._internal.common_cuda import with_tf32_off
|
||||
from torch.testing._internal.common_utils import (
|
||||
enable_profiling_mode_for_profiling_tests,
|
||||
IS_SANDCASTLE,
|
||||
raise_on_run_directly,
|
||||
skipIfCompiledWithoutNumpy,
|
||||
skipIfCrossRef,
|
||||
skipIfTorchDynamo,
|
||||
@ -46,14 +47,6 @@ from torch.testing._internal.jit_utils import (
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
|
||||
class TestTracer(JitTestCase):
|
||||
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
||||
@ -2826,3 +2819,7 @@ class TestMixTracingScripting(JitTestCase):
|
||||
for n in fn_t.graph.nodes():
|
||||
if n.kind() == "prim::CallFunction":
|
||||
self.assertTrue(n.output().isCompleteTensor())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -10,18 +10,13 @@ import torch
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import suppress_warnings
|
||||
from torch.testing._internal.common_utils import (
|
||||
raise_on_run_directly,
|
||||
suppress_warnings,
|
||||
)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestTypeSharing(JitTestCase):
|
||||
def assertSameType(self, m1, m2):
|
||||
if not isinstance(m1, torch.jit.ScriptModule):
|
||||
@ -626,3 +621,7 @@ class TestTypeSharing(JitTestCase):
|
||||
# of A, __jit_ignored_attributes__ was modified before scripting s2,
|
||||
# so the set of ignored attributes is different between s1 and s2.
|
||||
self.assertDifferentType(s1, s2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -12,6 +12,7 @@ import torch
|
||||
import torch.testing._internal.jit_utils
|
||||
from jit.test_module_interface import TestModuleInterface # noqa: F401
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
@ -19,13 +20,6 @@ from torch.testing._internal.jit_utils import JitTestCase
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestTypesAndAnnotation(JitTestCase):
|
||||
def test_pep585_type(self):
|
||||
@ -370,3 +364,7 @@ class TestTypesAndAnnotation(JitTestCase):
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "ErrorReason"):
|
||||
t = inferred_type.type()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -7,7 +7,7 @@ from collections import namedtuple
|
||||
from typing import Dict, List, NamedTuple, Tuple
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS, raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
|
||||
@ -15,13 +15,6 @@ from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestTyping(JitTestCase):
|
||||
def test_dict_in_not_in(self):
|
||||
@ -140,7 +133,7 @@ class TestTyping(JitTestCase):
|
||||
|
||||
# Check for invalid key and value type annotation
|
||||
def wrong_key_value_type(
|
||||
dictionary: Dict[torch.jit.ScriptModule, torch.jit.ScriptModule]
|
||||
dictionary: Dict[torch.jit.ScriptModule, torch.jit.ScriptModule],
|
||||
):
|
||||
return
|
||||
|
||||
@ -688,3 +681,7 @@ class TestTyping(JitTestCase):
|
||||
mod2 = LowestModule()
|
||||
mod_s = torch.jit.script(mod)
|
||||
mod2_s = torch.jit.script(mod2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -15,17 +15,10 @@ from torch.testing import FileCheck
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestUnion(JitTestCase):
|
||||
"""
|
||||
This class tests the functionality of `Union`.
|
||||
@ -1066,3 +1059,7 @@ class TestUnion(JitTestCase):
|
||||
# "Union[Dict[str, torch.Tensor], int]",
|
||||
# lhs["dict_comprehension_of_mixed"],
|
||||
# "foobar")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -16,17 +16,10 @@ from torch.testing import FileCheck
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 10), "Requires Python 3.10")
|
||||
class TestUnion(JitTestCase):
|
||||
"""
|
||||
@ -1064,3 +1057,7 @@ class TestUnion(JitTestCase):
|
||||
# "Dict[str, torch.Tensor] | int",
|
||||
# lhs["dict_comprehension_of_mixed"],
|
||||
# "foobar")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -10,16 +10,10 @@ import torch
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
# NOTE: FIXING FAILING TESTS
|
||||
# If you are seeing a test failure from this file, congrats, you improved
|
||||
# parity between JIT and Python API. Before you fix the test, you must also update
|
||||
@ -90,3 +84,7 @@ class TestUnsupportedOps(JitTestCase):
|
||||
func()
|
||||
with self.assertRaisesRegex(Exception, ""):
|
||||
torch.jit.script(func)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -13,17 +13,10 @@ from torch.testing import FileCheck
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestUpgraders(JitTestCase):
|
||||
def _load_model_version(self, loaded_model):
|
||||
buffer = io.BytesIO()
|
||||
@ -346,3 +339,7 @@ class TestUpgraders(JitTestCase):
|
||||
FileCheck().check_count("aten::full", 5).run(loaded_model.graph)
|
||||
version = self._load_model_version(loaded_model)
|
||||
self.assertTrue(version == 5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -13,17 +13,10 @@ from torch.testing import FileCheck
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestWarn(JitTestCase):
|
||||
def test_warn(self):
|
||||
@torch.jit.script
|
||||
@ -148,3 +141,7 @@ class TestWarn(JitTestCase):
|
||||
).run(
|
||||
f.getvalue()
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -6,7 +6,10 @@ import sys
|
||||
from typing import Any, List
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import skipIfTorchDynamo
|
||||
from torch.testing._internal.common_utils import (
|
||||
raise_on_run_directly,
|
||||
skipIfTorchDynamo,
|
||||
)
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
|
||||
@ -14,13 +17,6 @@ from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestWith(JitTestCase):
|
||||
"""
|
||||
@ -647,3 +643,7 @@ class TestWith(JitTestCase):
|
||||
# Nested record function should have child "aten::add"
|
||||
nested_child_events = nested_function_event.cpu_children
|
||||
self.assertTrue("aten::add" in (child.name for child in nested_child_events))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -184,3 +184,10 @@ class TestXNNPackBackend(unittest.TestCase):
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test is not currently used and should be "
|
||||
"enabled in discover_tests.py if required."
|
||||
)
|
||||
|
Reference in New Issue
Block a user