Add __main__ guards to jit tests (#154725)

This PR is part of a series attempting to re-submit https://github.com/pytorch/pytorch/pull/134592 as smaller PRs.

In jit tests:

- Add and use a common raise_on_run_directly method for when a user runs a test file directly which should not be run this way. Print the file which the user should have run.
- Raise a RuntimeError on tests which have been disabled (not run)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154725
Approved by: https://github.com/clee2000
This commit is contained in:
Anthony Barbier
2025-06-16 10:28:45 +00:00
committed by PyTorch MergeBot
parent f810e98143
commit bf7e290854
78 changed files with 451 additions and 518 deletions

View File

@ -2,18 +2,13 @@
import torch
from torch._C import parse_ir
from torch.testing._internal.common_utils import TemporaryFileName
from torch.testing._internal.common_utils import (
raise_on_run_directly,
TemporaryFileName,
)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestAliasAnalysis(JitTestCase):
def test_becomes_wildcard_annotations(self):
graph_str = """
@ -154,3 +149,7 @@ class TestAliasAnalysis(JitTestCase):
mod = ModuleWrapper(module_list)
mod = torch.jit.script(mod)
mod(torch.zeros((2, 2)))
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -16,6 +16,7 @@ from typing import List
from torch import Tensor
from torch.jit import Future
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import _inline_everything, JitTestCase
@ -547,8 +548,4 @@ class TestAsync(JitTestCase):
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
raise_on_run_directly("test/test_jit.py")

View File

@ -1,7 +1,7 @@
# Owner(s): ["oncall: jit"]
import torch
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
class TestAtenPow(TestCase):
@ -99,3 +99,7 @@ class TestAtenPow(TestCase):
self.assertEqual(fn_float_float(0.0, -0.0), 0.0 ** (-0.0))
# zero base and negative exponent case that should trigger RunTimeError
self.assertRaises(RuntimeError, fn_float_float, 0.0, -2.0)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -4,17 +4,10 @@ from typing import NamedTuple, Tuple
import torch
from torch.testing import FileCheck
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestGetDefaultAttr(JitTestCase):
def test_getattr_with_default(self):
class A(torch.nn.Module):
@ -66,3 +59,7 @@ class TestGetDefaultAttr(JitTestCase):
with self.assertRaisesRegex(RuntimeError, "but got a normal Tuple"):
torch.jit.script(fn)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -4,7 +4,10 @@
from typing import List
import torch
from torch.testing._internal.common_utils import skipIfTorchDynamo
from torch.testing._internal.common_utils import (
raise_on_run_directly,
skipIfTorchDynamo,
)
from torch.testing._internal.jit_utils import JitTestCase
@ -145,3 +148,7 @@ class TestAutodiffJit(JitTestCase):
self.assertEqual(x_s.requires_grad, x.requires_grad)
self.assertEqual(y_s.requires_grad, y.requires_grad)
self.assertEqual(z_s.requires_grad, z.requires_grad)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -20,20 +20,13 @@ sys.path.append(pytorch_test_dir)
from typing import List, Optional, Tuple
from torch.testing import FileCheck
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import (
disable_autodiff_subgraph_inlining,
JitTestCase,
)
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
@unittest.skipIf(
GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients"
)
@ -589,3 +582,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
FileCheck().check("= prim::DifferentiableGraph").check(
"with prim::DifferentiableGraph"
).check(" = aten::relu").check("requires_grad=0").check("aten::relu").run(graph)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -6,6 +6,7 @@ from typing import List, Optional, Tuple
import torch
from torch import Tensor
from torch._awaits import _Await as Await
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase, make_global
@ -390,3 +391,7 @@ class TestAwait(JitTestCase):
sm = torch.jit.load(iofile)
script_out_load = sm(inp)
self.assertTrue(torch.allclose(expected, script_out_load))
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -7,7 +7,11 @@ from pathlib import Path
import torch
import torch._C
from torch.testing._internal.common_utils import IS_FBCODE, skipIfTorchDynamo
from torch.testing._internal.common_utils import (
IS_FBCODE,
raise_on_run_directly,
skipIfTorchDynamo,
)
# hacky way to skip these tests in fbcode:
@ -28,13 +32,6 @@ else:
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
"""
Unit Tests for Nnapi backend with delegate
Inherits most tests from TestNNAPI, which loads Android NNAPI models
@ -139,3 +136,7 @@ method_compile_spec must use the following format:
def tearDown(self):
# Change dtype back to default (Otherwise, other unit tests will complain)
torch.set_default_dtype(self.default_dtype)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -15,6 +15,7 @@ from torch.testing._internal.common_utils import (
IS_MACOS,
IS_SANDCASTLE,
IS_WINDOWS,
raise_on_run_directly,
skipIfRocm,
TEST_WITH_ROCM,
)
@ -25,13 +26,6 @@ from torch.testing._internal.jit_utils import JitTestCase
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
def to_test_backend(module, method_compile_spec):
return torch._C._jit_to_backend(
@ -822,3 +816,7 @@ class AddedAttributesTest(JitBackendTestCase):
)
self.assertEqual(pre_bundled, post_bundled)
self.assertEqual(post_bundled, post_load)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -2,24 +2,19 @@
import torch
from torch.testing import FileCheck
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestBatchMM(JitTestCase):
@staticmethod
def _get_test_tensors(n: int):
return [
torch.tensor([[1 + x, 2 + x, 3 + x], [4 + x, 5 + x, 6 + x]])
if x % 2 == 0
else torch.tensor([[1 + x, 2 + x], [3 + x, 4 + x], [5 + x, 6 + x]])
(
torch.tensor([[1 + x, 2 + x, 3 + x], [4 + x, 5 + x, 6 + x]])
if x % 2 == 0
else torch.tensor([[1 + x, 2 + x], [3 + x, 4 + x], [5 + x, 6 + x]])
)
for x in range(n)
]
@ -288,3 +283,7 @@ class TestBatchMM(JitTestCase):
FileCheck().check_count("aten::mm", 10, exactly=True).check_not(
"prim::MMBatchSide"
).run(test_batch_mm.graph)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -13,17 +13,10 @@ from torch.testing import FileCheck
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestBuiltins(JitTestCase):
"""
Tests for TorchScript support of Python builtin functions.
@ -299,3 +292,7 @@ class TestTensorBuiltins(JitTestCase):
self.assertEqual(
test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor)
)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -18,18 +18,14 @@ sys.path.append(pytorch_test_dir)
from typing import Dict, Iterable, List, Optional, Tuple
import torch.testing._internal.jit_utils
from torch.testing._internal.common_utils import IS_SANDCASTLE, skipIfTorchDynamo
from torch.testing._internal.common_utils import (
IS_SANDCASTLE,
raise_on_run_directly,
skipIfTorchDynamo,
)
from torch.testing._internal.jit_utils import JitTestCase, make_global
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestClassType(JitTestCase):
def test_reference_semantics(self):
"""
@ -1667,3 +1663,7 @@ class TestClassType(JitTestCase):
for fn in (fn_a, fn_b, fn_c, fn_d, fn_e):
with self.assertRaisesRegex(RuntimeError, error_message_regex):
torch.jit.script(fn)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -8,7 +8,7 @@ from textwrap import dedent
from typing import Dict, List
import torch
from torch.testing._internal.common_utils import IS_MACOS
from torch.testing._internal.common_utils import IS_MACOS, raise_on_run_directly
from torch.testing._internal.jit_utils import execWrapper, JitTestCase
@ -617,3 +617,7 @@ class TestComplex(JitTestCase):
scripted = torch.jit.script(op)
jit_result = scripted(x, y)
self.assertEqual(eager_result, jit_result)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -13,7 +13,6 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import (
IS_FBCODE,
run_tests,
set_default_dtype,
suppress_warnings,
)
@ -105,4 +104,7 @@ class TestComplexity(JitTestCase):
if __name__ == "__main__":
run_tests()
raise RuntimeError(
"This test is not currently used and should be "
"enabled in discover_tests.py if required."
)

View File

@ -22,16 +22,10 @@ skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
activations = [
F.celu,
F.elu,
@ -204,3 +198,7 @@ class TestInplaceToFunctionalActivation(JitTestCase):
inp = torch.randn(N, C, H, W)
self.run_pass("inplace_to_functional_activation", frozen_model.graph)
self.assertEqual(model(inp), frozen_model(inp))
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -12,6 +12,7 @@ from torch.testing import FileCheck
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import (
NoTest,
raise_on_run_directly,
skipCUDANonDefaultStreamIf,
skipIfRocm,
TEST_CUDA,
@ -36,13 +37,6 @@ if TEST_CUDA:
torch.ones(1).cuda() # initialize cuda context
TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 5e9
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestCUDA(JitTestCase):
"""
@ -698,3 +692,7 @@ class TestCUDA(JitTestCase):
FileCheck().check("cuda::_maybe_exchange_device(").run(g)
torch._C._jit_pass_inline(g)
FileCheck().check("cuda::_maybe_exchange_device(").run(g)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -10,17 +10,10 @@ import torch
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
def canonical(graph):
return torch._C._jit_pass_canonicalize(graph).str(False)
@ -151,3 +144,7 @@ graph(%x.1 : Tensor):
def test_where_no_scalar(self):
x = torch.rand(1, 3, 224, 224)
torch.ops.aten.where(x > 0.5, -1.5, 1.5) # does not raise
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -12,17 +12,10 @@ import torch.nn.parallel as dp
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA_MULTI_GPU
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestDataParallel(JitTestCase):
class Mpy(torch.nn.Module):
def __init__(self) -> None:
@ -158,3 +151,7 @@ class TestDataParallel(JitTestCase):
x1 = torch.ones(2, 2, requires_grad=True).cuda(device=1)
r1_forward = replica[1](x1)
self.assertEqual(first_forward, r1_forward)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -7,6 +7,7 @@ from typing import List, Optional
from hypothesis import given, settings, strategies as st
import torch
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
@ -168,3 +169,7 @@ class TestDataclasses(JitTestCase):
with self.assertRaises(OSError):
torch.jit.script(f)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -2,6 +2,7 @@
import torch
from torch.testing import FileCheck
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase, make_global
@ -73,3 +74,7 @@ class TestDCE(JitTestCase):
torch._C._jit_pass_dce_graph(fn_s.graph)
FileCheck().check("aten::add_").run(fn_s.graph)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -19,3 +19,10 @@ class TestDecorator(JitTestCase):
fn = my_function_a
fx = torch.jit.script(fn)
self.assertEqual(fn(1.0), fx(1.0))
if __name__ == "__main__":
raise RuntimeError(
"This test is not currently used and should be "
"enabled in discover_tests.py if required."
)

View File

@ -5,7 +5,7 @@ from itertools import product
import torch
from torch.jit._passes._property_propagation import apply_input_props_using_example
from torch.testing._internal.common_utils import TEST_CUDA
from torch.testing._internal.common_utils import raise_on_run_directly, TEST_CUDA
from torch.testing._internal.jit_utils import JitTestCase
@ -14,13 +14,6 @@ try:
except ImportError:
models = None
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestDeviceAnalysis(JitTestCase):
@classmethod
@ -336,3 +329,7 @@ class TestDeviceAnalysis(JitTestCase):
test_fn, [self.mkldnn, self.mkldnn, None, None], self.mkldnn
)
self.assert_device_equal(test_fn, [self.cpu, self.cuda, None, None], None)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -17,7 +17,11 @@ from torch.testing._internal.common_methods_invocations import (
sample_inputs_conv2d,
SampleInput,
)
from torch.testing._internal.common_utils import first_sample, set_default_dtype
from torch.testing._internal.common_utils import (
first_sample,
raise_on_run_directly,
set_default_dtype,
)
from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn
from torch.testing._internal.jit_utils import JitTestCase
@ -27,14 +31,6 @@ Dtype Analysis relies on symbolic shape analysis, which is still in beta
"""
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
custom_rules_works_list = {
"nn.functional.adaptive_avg_pool1d",
"nn.functional.adaptive_avg_pool2d",
@ -386,3 +382,6 @@ class TestDtypeCustomRules(TestDtypeBase):
TestDtypeCustomRulesCPU = None
# This creates TestDtypeCustomRulesCPU
instantiate_device_type_tests(TestDtypeCustomRules, globals(), only_for=("cpu",))
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -12,17 +12,10 @@ from torch.testing import FileCheck
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase, make_global
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestEnum(JitTestCase):
def test_enum_value_types(self):
class IntEnum(Enum):
@ -358,3 +351,7 @@ class TestEnum(JitTestCase):
@torch.jit.script
def is_red(x: Color) -> bool:
return x == Color.RED
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -197,3 +197,10 @@ class TestException(TestCase):
"jit.myexception.MyKeyError: This is a user defined key error",
):
fn()
if __name__ == "__main__":
raise RuntimeError(
"This test is not currently used and should be "
"enabled in discover_tests.py if required."
)

View File

@ -15,6 +15,7 @@ from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.testing._internal.common_utils import (
raise_on_run_directly,
set_default_dtype,
skipCUDAMemoryLeakCheckIf,
skipIfTorchDynamo,
@ -32,13 +33,6 @@ except ImportError:
HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
TEST_ROCM = torch.cuda.is_available() and torch.version.hip is not None
@ -55,7 +49,7 @@ class TestFreezing(JitTestCase):
self.a = 1 # folded
self.b = 1.2 # folded
self.c = "hello" # folded
self.c2 = "hi\xA1" # not folded
self.c2 = "hi\xa1" # not folded
self.d = [1, 1] # folded
self.e = [1.0, 1.1] # folded
self.f = ["hello", "world"] # folded
@ -67,7 +61,7 @@ class TestFreezing(JitTestCase):
torch.tensor([5.5], requires_grad=True),
) # folded
self.h = {"layer": [torch.tensor([7.7], requires_grad=True)]}
self.h2 = {"layer\xB1": [torch.tensor([8.8], requires_grad=True)]}
self.h2 = {"layer\xb1": [torch.tensor([8.8], requires_grad=True)]}
self.t = torch.tensor([1.2, 2.4], requires_grad=True) # folded
self.ts = [
torch.tensor([1.0, 2.0], requires_grad=True),
@ -3461,3 +3455,7 @@ class TestMKLDNNReinplacing(JitTestCase):
mod = self.freezeAndConvert(mod_eager)
FileCheck().check("aten::add_").run(mod.graph)
self.checkResults(mod_eager, mod)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -10,17 +10,10 @@ from torch.testing import FileCheck
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestFunctionalBlocks(JitTestCase):
def test_subgraph_creation(self):
def fn(x, y, z):
@ -54,3 +47,7 @@ class TestFunctionalBlocks(JitTestCase):
FileCheck().check("add").check("add_").check_not("mul").check(
"FunctionalGraph"
).run(graph)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -1,6 +1,7 @@
# Owner(s): ["oncall: jit"]
import torch
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
@ -19,3 +20,7 @@ class TestFuserCommon(JitTestCase):
# test fallback when optimization is not applicable
y = fn(torch.randn(5, requires_grad=rq))
self.assertEqual(y.requires_grad, rq)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit_fuser_te.py")

View File

@ -6,18 +6,13 @@ import unittest
import torch
from torch.nn import init
from torch.testing._internal.common_utils import skipIfLegacyJitExecutor
from torch.testing._internal.common_utils import (
raise_on_run_directly,
skipIfLegacyJitExecutor,
)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestGenerator(JitTestCase):
# torch.jit.trace does not properly capture the generator manual seed
# and thus is non deterministic even if the generator is manually seeded
@ -193,3 +188,7 @@ class TestGenerator(JitTestCase):
except: # noqa: B001, E722
print(loaded_module.forward.code)
raise
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -3,6 +3,7 @@
import torch
import torch._C
from torch.testing import FileCheck
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
@ -59,3 +60,7 @@ class TestGraphRewritePasses(JitTestCase):
FileCheck().check_not("aten::linear").run(model.graph)
# make sure it runs
model(x)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -10,17 +10,10 @@ import torch
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestHash(JitTestCase):
def test_hash_tuple(self):
def fn(t1: Tuple[int, int], t2: Tuple[int, int]) -> bool:
@ -115,3 +108,7 @@ class TestHash(JitTestCase):
self.checkScript(fn, (gpu0, gpu1))
self.checkScript(fn, (gpu0, cpu))
self.checkScript(fn, (cpu, cpu))
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -33,17 +33,10 @@ from jit.test_hooks_modules import (
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
# Tests for JIT forward hooks and pre-hooks
class TestHooks(JitTestCase):
def test_module_no_forward_input(self):
@ -393,3 +386,7 @@ class TestHooks(JitTestCase):
r"Received type: 'str'. Expected type: 'Tuple\[str\]'",
):
torch.jit.script(m)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -528,3 +528,9 @@ def create_submodule_forward_single_input_return_not_tupled():
m.submodule.register_forward_hook(forward_hook)
return m
if __name__ == "__main__":
raise RuntimeError(
"This file is a collection of utils, it should be imported not executed directly"
)

View File

@ -11,17 +11,10 @@ from torch.testing import FileCheck
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
# Tests that Python slice class is supported in TorchScript
class TestIgnorableArgs(JitTestCase):
def test_slice_ignorable_args_for_slice(self):
@ -61,3 +54,7 @@ class TestIgnorableArgs(JitTestCase):
torch.add(x, y, out=y)
FileCheck().check("torch.add(x, y, out=y)").run(fn.code)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -11,17 +11,10 @@ import torch
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.jit.frontend import _IS_ASTUNPARSE_INSTALLED
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestIgnoreContextManager(JitTestCase):
@unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
def test_with_ignore_context_manager_with_inp_out(self):
@ -103,3 +96,7 @@ class TestIgnoreContextManager(JitTestCase):
s = torch.jit.script(model)
self.assertEqual(s(), 5)
self.assertEqual(s(), model())
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -11,17 +11,10 @@ import torch
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
# Tests for torch.jit.isinstance
class TestIsinstance(JitTestCase):
def test_int(self):
@ -354,3 +347,7 @@ class TestIsinstance(JitTestCase):
# Should not throw "Boolean value of Tensor with more than
# one value is ambiguous" error
torch._jit_internal.check_empty_containers(torch.rand(2, 3))
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -11,17 +11,10 @@ from torch.testing._internal import jit_utils
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
# Tests various JIT-related utility functions.
class TestJitUtils(JitTestCase):
# Tests that POSITIONAL_OR_KEYWORD arguments are captured.
@ -116,3 +109,7 @@ class TestJitUtils(JitTestCase):
with jit_utils.NoTracerWarnContextManager():
self.assertEqual(False, torch._C._jit_get_tracer_state_warn())
self.assertEqual(True, torch._C._jit_get_tracer_state_warn())
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -19,18 +19,14 @@ from torch.testing import FileCheck
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import skipIfTorchDynamo, TEST_CUDA
from torch.testing._internal.common_utils import (
raise_on_run_directly,
skipIfTorchDynamo,
TEST_CUDA,
)
from torch.testing._internal.jit_utils import JitTestCase, make_global
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestList(JitTestCase):
def test_list_bool_conversion(self):
def if_predicate(l: List[int]):
@ -1825,7 +1821,7 @@ class TestDict(JitTestCase):
def test_popitem(self):
@torch.jit.script
def popitem(
x: Dict[str, Tensor]
x: Dict[str, Tensor],
) -> Tuple[Tuple[str, Tensor], Dict[str, Tensor]]:
item = x.popitem()
return item, x
@ -2996,3 +2992,7 @@ class TestScriptList(JitTestCase):
for i in range(300):
test = Test()
test_script = torch.jit.script(test)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -10,17 +10,10 @@ import torch
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestLogging(JitTestCase):
def test_bump_numeric_counter(self):
class ModuleThatLogs(torch.jit.ScriptModule):
@ -122,3 +115,7 @@ class TestLogging(JitTestCase):
def test_logging_levels_set(self):
torch._C._jit_set_logging_option("foo")
self.assertEqual("foo", torch._C._jit_get_logging_option())
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -12,7 +12,7 @@ import torch.testing._internal.jit_utils
from jit.test_module_interface import TestModuleInterface # noqa: F401
from torch import jit
from torch.testing import FileCheck
from torch.testing._internal.common_utils import freeze_rng_state
from torch.testing._internal.common_utils import freeze_rng_state, raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase, make_global, RUN_CUDA_HALF
@ -20,13 +20,6 @@ from torch.testing._internal.jit_utils import JitTestCase, make_global, RUN_CUDA
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestMisc(JitTestCase):
def test_joined_str(self):
@ -129,7 +122,7 @@ class TestMisc(JitTestCase):
def test_subexpression_Tuple_int_int_Future(self):
@torch.jit.script
def fn(
x: Tuple[int, int, torch.jit.Future[int]]
x: Tuple[int, int, torch.jit.Future[int]],
) -> Tuple[int, torch.jit.Future[int]]:
return x[0], x[2]
@ -147,7 +140,7 @@ class TestMisc(JitTestCase):
def test_subexpression_Optional(self):
@torch.jit.script
def fn(
x: Optional[Dict[int, torch.jit.Future[int]]]
x: Optional[Dict[int, torch.jit.Future[int]]],
) -> Optional[torch.jit.Future[int]]:
if x is not None:
return x[0]
@ -504,3 +497,7 @@ class TestMisc(JitTestCase):
self.assertTrue(len(complex_indices) > 0)
self.assertTrue(len(Scalar_indices) > 0)
self.assertTrue(complex_indices[0] > Scalar_indices[0])
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -11,24 +11,19 @@ from torch.testing._internal.common_utils import (
enable_profiling_mode_for_profiling_tests,
GRAPH_EXECUTOR,
ProfilingMode,
raise_on_run_directly,
set_default_dtype,
slowTest,
suppress_warnings,
)
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import slowTest, suppress_warnings
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
try:
import torchvision
@ -84,7 +79,7 @@ class TestModels(JitTestCase):
nn.ReLU(True),
# state size. (ngf) x 32 x 32
nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
nn.Tanh()
nn.Tanh(),
# state size. (nc) x 64 x 64
)
@ -754,3 +749,7 @@ class TestModels(JitTestCase):
m = self.createFunctionFromGraph(g)
with torch.random.fork_rng(devices=[]):
self.assertEqual(outputs, m(*inputs))
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -5,6 +5,7 @@ import sys
from typing import Any, Dict, List
import torch
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
@ -12,13 +13,6 @@ from torch.testing._internal.jit_utils import JitTestCase
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestModuleAPIs(JitTestCase):
def test_default_state_dict_methods(self):
@ -141,3 +135,7 @@ class TestModuleAPIs(JitTestCase):
self.assertFalse(m2.sub.customized_load_state_dict_called)
m2.load_state_dict(state_dict)
self.assertTrue(m2.sub.customized_load_state_dict_called)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -7,6 +7,7 @@ from typing import Any, List, Tuple
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
@ -14,13 +15,6 @@ from torch.testing._internal.jit_utils import JitTestCase
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestModuleContainers(JitTestCase):
def test_sequential_intermediary_types(self):
@ -756,3 +750,7 @@ class TestModuleContainers(JitTestCase):
)
self.checkModule(MyModule(), (torch.ones(1),))
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -8,6 +8,7 @@ from typing import Any, List
import torch
import torch.nn as nn
from torch import Tensor
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase, make_global
@ -15,13 +16,6 @@ from torch.testing._internal.jit_utils import JitTestCase, make_global
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class OrigModule(nn.Module):
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
@ -701,3 +695,7 @@ class TestModuleInterface(JitTestCase):
with self.assertRaisesRegex(Exception, "Could not compile"):
scripted_mod = torch.jit.script(TestModule())
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -4,6 +4,7 @@ import os
import sys
import torch
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
@ -11,13 +12,6 @@ from torch.testing._internal.jit_utils import JitTestCase
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestModules(JitTestCase):
def test_script_module_with_constants_list(self):
@ -36,3 +30,7 @@ class TestModules(JitTestCase):
self.x = 0
self.checkModule(Net(), (torch.randn(5),))
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -2,17 +2,10 @@
import torch
from torch.testing import FileCheck
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestOpDecompositions(JitTestCase):
def test_op_decomposition(self):
def foo(x):
@ -42,3 +35,7 @@ class TestOpDecompositions(JitTestCase):
FileCheck().check_not("aten::square").check("aten::pow").run(foo.graph)
x = torch.rand([4])
self.assertEqual(foo(x), torch.square(x))
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -3,7 +3,7 @@
import torch
import torch._C
import torch.nn.functional as F
from torch.testing._internal.common_utils import skipIfNoXNNPACK
from torch.testing._internal.common_utils import raise_on_run_directly, skipIfNoXNNPACK
from torch.testing._internal.jit_utils import JitTestCase
@ -263,3 +263,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
conv2d_activation=F.relu,
conv2d_activation_kind="aten::relu",
)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -4,17 +4,10 @@
import torch
import torch.nn.utils.parametrize as parametrize
from torch import nn
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestParametrization(JitTestCase):
# Define some parametrization
class Symmetric(nn.Module):
@ -68,3 +61,7 @@ class TestParametrization(JitTestCase):
# Check the scripting process throws an error when caching
with self.assertRaisesRegex(RuntimeError, "Caching is not implemented"):
scripted_model = torch.jit.trace_module(model)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -6,7 +6,7 @@ from typing import Any, Dict, List, NamedTuple, Optional, Tuple # noqa: F401
import torch
from torch.jit._monkeytype_config import _IS_MONKEYTYPE_INSTALLED
from torch.testing._internal.common_utils import NoTest
from torch.testing._internal.common_utils import NoTest, raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase, make_global
@ -21,13 +21,6 @@ if not _IS_MONKEYTYPE_INSTALLED:
)
JitTestCase = NoTest # type: ignore[misc, assignment] # noqa: F811
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestPDT(JitTestCase):
"""
@ -896,3 +889,7 @@ class TestPDT(JitTestCase):
torch.ones(1),
),
)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -6,17 +6,10 @@ from typing import Callable, List
import torch
from torch import nn
from torch.testing import FileCheck
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import _inline_everything, JitTestCase, RUN_CUDA
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestPeephole(JitTestCase):
def test_peephole_with_writes(self):
def test_write(x):
@ -890,3 +883,7 @@ class TestPeephole(JitTestCase):
self.run_pass("peephole", foo.graph)
FileCheck().check("aten::slice").run(foo.graph)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -4,7 +4,10 @@ import os
import sys
import torch
from torch.testing._internal.common_utils import skipIfTorchDynamo
from torch.testing._internal.common_utils import (
raise_on_run_directly,
skipIfTorchDynamo,
)
# Make the helper files in test/ importable
@ -13,14 +16,6 @@ sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import FileCheck, JitTestCase, warmup_backward
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
@skipIfTorchDynamo()
class TestProfiler(JitTestCase):
def setUp(self):
@ -284,3 +279,7 @@ class TestProfiler(JitTestCase):
g = torch.jit.last_executed_optimized_graph()
self.assertEqual(len(list(g.findAllNodes("prim::TensorExprGroup"))), 2)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -2,17 +2,10 @@
import torch
from torch.testing import FileCheck
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TestPythonBindings\n\n"
"instead."
)
class TestPythonBindings(JitTestCase):
def test_cu_get_functions(self):
@torch.jit.script
@ -114,3 +107,7 @@ graph(%p207 : Tensor,
graph3 = torch._C.parse_ir(ir)
graph3 = torch._C._jit_pass_canonicalize(graph3, False)
FileCheck().check_not("%p207").run(graph3)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -7,6 +7,7 @@ import tempfile
from textwrap import dedent
import torch
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import execWrapper, JitTestCase
@ -14,13 +15,6 @@ from torch.testing._internal.jit_utils import execWrapper, JitTestCase
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
def get_fn(file_name, script_path):
import importlib.util
@ -473,3 +467,7 @@ class TestPythonBuiltinOP(JitTestCase):
s = torch.rand(1)
self.assertTrue(foo(s))
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -6,18 +6,10 @@ import numpy as np
import torch
from torch.testing import FileCheck
from torch.testing._internal.common_utils import IS_MACOS
from torch.testing._internal.common_utils import IS_MACOS, raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestPythonIr(JitTestCase):
def test_param_strides(self):
def trace_me(arg):
@ -100,3 +92,7 @@ class TestPythonIr(JitTestCase):
FileCheck().check_not("aten::mul").check("aten::add").run(foo.graph)
self.assertEqual(foo(torch.ones([2, 2])), torch.ones([2, 2]) * 4)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -20,20 +20,13 @@ from torch.testing import FileCheck
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import (
_tmp_donotuse_dont_inline_everything,
JitTestCase,
)
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestRecursiveScript(JitTestCase):
def test_inferred_nonetype(self):
class M(nn.Module):
@ -799,3 +792,7 @@ class TestRecursiveScript(JitTestCase):
# ScriptModule should correctly reflect the override.
s = torch.jit.script(m)
self.assertEqual(s.i_am_ignored(), "new")
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -11,17 +11,10 @@ from torch.testing import FileCheck
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import freeze_rng_state, JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestRemoveMutation(JitTestCase):
def test_aten_inplace(self):
def test_not_new_alias(x):
@ -318,3 +311,7 @@ class TestRemoveMutation(JitTestCase):
self.run_pass("remove_mutation", mod_script.forward.graph)
FileCheck().check("aten::add_").run(test_multiple_uses.graph)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -8,7 +8,11 @@ from typing import NamedTuple, Optional
import torch
from torch import Tensor
from torch.testing._internal.common_utils import skipIfTorchDynamo, TemporaryFileName
from torch.testing._internal.common_utils import (
raise_on_run_directly,
skipIfTorchDynamo,
TemporaryFileName,
)
# Make the helper files in test/ importable
@ -17,14 +21,6 @@ sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import clear_class_registry, JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestSaveLoad(JitTestCase):
def test_different_modules(self):
"""
@ -1197,3 +1193,7 @@ class TestSaveLoadFlatbuffer(JitTestCase):
torch._C._get_model_extra_files_from_buffer(script_module_io, re_extra_files)
self.assertEqual(extra_files, re_extra_files)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -17,17 +17,10 @@ import torch
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.jit.mobile import _load_for_lite_interpreter
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestSaveLoadForOpVersion(JitTestCase):
# Helper that returns the module after saving and loading
def _save_load_module(self, m):
@ -617,3 +610,7 @@ class TestSaveLoadForOpVersion(JitTestCase):
self.assertTrue(output.size(dim=0) == 100)
# "Upgraded" model should match the new version output
self.assertEqual(output, output_current)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -10,17 +10,10 @@ from torch import nn
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class Sequence(nn.Module):
def __init__(self) -> None:
super().__init__()
@ -115,3 +108,7 @@ class TestScriptProfile(JitTestCase):
p.enable()
p.disable()
self.assertEqual(p.dump_string(), "")
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -11,17 +11,10 @@ import torch
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
# NB: There are no tests for `Tuple` or `NamedTuple` here. In fact,
# reassigning a non-empty Tuple to an attribute previously typed
@ -363,3 +356,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
"empty non-base types",
):
torch.jit.script(M())
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -10,17 +10,10 @@ import torch
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
# Tests that Python slice class is supported in TorchScript
class TestSlice(JitTestCase):
def test_slice_kwarg(self):
@ -178,3 +171,7 @@ class TestSlice(JitTestCase):
self.assertEqual(result2[0].identifier, "B")
self.assertEqual(result2[1].identifier, "C")
self.assertEqual(result2[2].identifier, "D")
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -4,7 +4,11 @@ import io
import unittest
import torch
from torch.testing._internal.common_utils import IS_WINDOWS, TEST_MKL
from torch.testing._internal.common_utils import (
IS_WINDOWS,
raise_on_run_directly,
TEST_MKL,
)
from torch.testing._internal.jit_utils import JitTestCase
@ -118,3 +122,7 @@ class TestSparse(JitTestCase):
loaded_result = loaded_model.forward(x)
self.assertEqual(expected_result.to_dense(), loaded_result.to_dense())
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -10,17 +10,10 @@ import torch
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestStringFormatting(JitTestCase):
def test_modulo_operator(self):
def fn(dividend: int, divisor: int) -> int:
@ -199,3 +192,7 @@ class TestStringFormatting(JitTestCase):
'"%a in template" % arg1',
):
fn("foo")
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -9,18 +9,10 @@ import torch
from torch import nn, Tensor
from torch.testing import FileCheck
from torch.testing._internal.common_methods_invocations import sample_inputs_cat_concat
from torch.testing._internal.common_utils import make_tensor
from torch.testing._internal.common_utils import make_tensor, raise_on_run_directly
from torch.testing._internal.jit_utils import execWrapper, JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
# XXX: still in prototype
class TestSymbolicShapeAnalysis(JitTestCase):
def setUp(self):
@ -819,3 +811,7 @@ class TestSymbolicShapeAnalysis(JitTestCase):
input.setType(input.type().with_sizes([1, 5, 8]))
torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)
self.assertEqual(next(foo.graph.outputs()).type().symbolic_sizes(), [5, 8])
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -9,17 +9,10 @@ import torch
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestTensorCreationOps(JitTestCase):
"""
A suite of tests for ops that create tensors.
@ -78,3 +71,7 @@ class TestTensorCreationOps(JitTestCase):
assert indices.dtype == torch.int32
self.checkScript(tril_indices, (3, 3))
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -10,17 +10,10 @@ import torch
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing import FileCheck
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestTensorMethods(JitTestCase):
def test_getitem(self):
def tensor_getitem(inp: torch.Tensor):
@ -41,3 +34,7 @@ class TestTensorMethods(JitTestCase):
RuntimeError, "expected exactly 1 argument", "inp.__getitem__"
):
torch.jit.script(tensor_getitem_invalid)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -8,7 +8,10 @@ import sys
from typing import Optional
import torch
from torch.testing._internal.common_utils import skipIfTorchDynamo
from torch.testing._internal.common_utils import (
raise_on_run_directly,
skipIfTorchDynamo,
)
# Make the helper files in test/ importable
@ -19,14 +22,6 @@ from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.torchbind_impls import load_torchbind_test_lib
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
@skipIfTorchDynamo("skipping as a precaution")
class TestTorchbind(JitTestCase):
def setUp(self):
@ -463,3 +458,7 @@ class TestTorchbind(JitTestCase):
return obj.decrement()
self.checkScript(gn, ())
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -29,6 +29,7 @@ from torch.testing._internal.common_cuda import with_tf32_off
from torch.testing._internal.common_utils import (
enable_profiling_mode_for_profiling_tests,
IS_SANDCASTLE,
raise_on_run_directly,
skipIfCompiledWithoutNumpy,
skipIfCrossRef,
skipIfTorchDynamo,
@ -46,14 +47,6 @@ from torch.testing._internal.jit_utils import (
)
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
class TestTracer(JitTestCase):
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@ -2826,3 +2819,7 @@ class TestMixTracingScripting(JitTestCase):
for n in fn_t.graph.nodes():
if n.kind() == "prim::CallFunction":
self.assertTrue(n.output().isCompleteTensor())
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -10,18 +10,13 @@ import torch
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import suppress_warnings
from torch.testing._internal.common_utils import (
raise_on_run_directly,
suppress_warnings,
)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestTypeSharing(JitTestCase):
def assertSameType(self, m1, m2):
if not isinstance(m1, torch.jit.ScriptModule):
@ -626,3 +621,7 @@ class TestTypeSharing(JitTestCase):
# of A, __jit_ignored_attributes__ was modified before scripting s2,
# so the set of ignored attributes is different between s1 and s2.
self.assertDifferentType(s1, s2)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -12,6 +12,7 @@ import torch
import torch.testing._internal.jit_utils
from jit.test_module_interface import TestModuleInterface # noqa: F401
from torch.testing import FileCheck
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
@ -19,13 +20,6 @@ from torch.testing._internal.jit_utils import JitTestCase
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestTypesAndAnnotation(JitTestCase):
def test_pep585_type(self):
@ -370,3 +364,7 @@ class TestTypesAndAnnotation(JitTestCase):
with self.assertRaisesRegex(RuntimeError, "ErrorReason"):
t = inferred_type.type()
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -7,7 +7,7 @@ from collections import namedtuple
from typing import Dict, List, NamedTuple, Tuple
import torch
from torch.testing._internal.common_utils import IS_WINDOWS
from torch.testing._internal.common_utils import IS_WINDOWS, raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase, make_global
@ -15,13 +15,6 @@ from torch.testing._internal.jit_utils import JitTestCase, make_global
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestTyping(JitTestCase):
def test_dict_in_not_in(self):
@ -140,7 +133,7 @@ class TestTyping(JitTestCase):
# Check for invalid key and value type annotation
def wrong_key_value_type(
dictionary: Dict[torch.jit.ScriptModule, torch.jit.ScriptModule]
dictionary: Dict[torch.jit.ScriptModule, torch.jit.ScriptModule],
):
return
@ -688,3 +681,7 @@ class TestTyping(JitTestCase):
mod2 = LowestModule()
mod_s = torch.jit.script(mod)
mod2_s = torch.jit.script(mod2)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -15,17 +15,10 @@ from torch.testing import FileCheck
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase, make_global
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestUnion(JitTestCase):
"""
This class tests the functionality of `Union`.
@ -1066,3 +1059,7 @@ class TestUnion(JitTestCase):
# "Union[Dict[str, torch.Tensor], int]",
# lhs["dict_comprehension_of_mixed"],
# "foobar")
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -16,17 +16,10 @@ from torch.testing import FileCheck
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase, make_global
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
@unittest.skipIf(sys.version_info < (3, 10), "Requires Python 3.10")
class TestUnion(JitTestCase):
"""
@ -1064,3 +1057,7 @@ class TestUnion(JitTestCase):
# "Dict[str, torch.Tensor] | int",
# lhs["dict_comprehension_of_mixed"],
# "foobar")
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -10,16 +10,10 @@ import torch
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
# NOTE: FIXING FAILING TESTS
# If you are seeing a test failure from this file, congrats, you improved
# parity between JIT and Python API. Before you fix the test, you must also update
@ -90,3 +84,7 @@ class TestUnsupportedOps(JitTestCase):
func()
with self.assertRaisesRegex(Exception, ""):
torch.jit.script(func)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -13,17 +13,10 @@ from torch.testing import FileCheck
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestUpgraders(JitTestCase):
def _load_model_version(self, loaded_model):
buffer = io.BytesIO()
@ -346,3 +339,7 @@ class TestUpgraders(JitTestCase):
FileCheck().check_count("aten::full", 5).run(loaded_model.graph)
version = self._load_model_version(loaded_model)
self.assertTrue(version == 5)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -13,17 +13,10 @@ from torch.testing import FileCheck
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestWarn(JitTestCase):
def test_warn(self):
@torch.jit.script
@ -148,3 +141,7 @@ class TestWarn(JitTestCase):
).run(
f.getvalue()
)
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -6,7 +6,10 @@ import sys
from typing import Any, List
import torch
from torch.testing._internal.common_utils import skipIfTorchDynamo
from torch.testing._internal.common_utils import (
raise_on_run_directly,
skipIfTorchDynamo,
)
from torch.testing._internal.jit_utils import JitTestCase, make_global
@ -14,13 +17,6 @@ from torch.testing._internal.jit_utils import JitTestCase, make_global
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestWith(JitTestCase):
"""
@ -647,3 +643,7 @@ class TestWith(JitTestCase):
# Nested record function should have child "aten::add"
nested_child_events = nested_function_event.cpu_children
self.assertTrue("aten::add" in (child.name for child in nested_child_events))
if __name__ == "__main__":
raise_on_run_directly("test/test_jit.py")

View File

@ -184,3 +184,10 @@ class TestXNNPackBackend(unittest.TestCase):
}
},
)
if __name__ == "__main__":
raise RuntimeError(
"This test is not currently used and should be "
"enabled in discover_tests.py if required."
)