mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Add __main__ guards to jit tests (#154725)"
This reverts commit 1a55fb0ee87eaa8b376aaa82d95d213fe0fbe64b. Reverted https://github.com/pytorch/pytorch/pull/154725 on behalf of https://github.com/malfet due to This added 2nd copy of raise_on_run to common_utils.py which caused lint failures, see https://github.com/pytorch/pytorch/actions/runs/15445374980/job/43473457466 ([comment](https://github.com/pytorch/pytorch/pull/154725#issuecomment-2940503905))
This commit is contained in:
@ -2,13 +2,18 @@
|
||||
|
||||
import torch
|
||||
from torch._C import parse_ir
|
||||
from torch.testing._internal.common_utils import (
|
||||
raise_on_run_directly,
|
||||
TemporaryFileName,
|
||||
)
|
||||
from torch.testing._internal.common_utils import TemporaryFileName
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestAliasAnalysis(JitTestCase):
|
||||
def test_becomes_wildcard_annotations(self):
|
||||
graph_str = """
|
||||
@ -149,7 +154,3 @@ class TestAliasAnalysis(JitTestCase):
|
||||
mod = ModuleWrapper(module_list)
|
||||
mod = torch.jit.script(mod)
|
||||
mod(torch.zeros((2, 2)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -16,7 +16,6 @@ from typing import List
|
||||
|
||||
from torch import Tensor
|
||||
from torch.jit import Future
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import _inline_everything, JitTestCase
|
||||
|
||||
|
||||
@ -548,4 +547,8 @@ class TestAsync(JitTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly, TestCase
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
|
||||
|
||||
class TestAtenPow(TestCase):
|
||||
@ -99,7 +99,3 @@ class TestAtenPow(TestCase):
|
||||
self.assertEqual(fn_float_float(0.0, -0.0), 0.0 ** (-0.0))
|
||||
# zero base and negative exponent case that should trigger RunTimeError
|
||||
self.assertRaises(RuntimeError, fn_float_float, 0.0, -2.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -4,10 +4,17 @@ from typing import NamedTuple, Tuple
|
||||
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestGetDefaultAttr(JitTestCase):
|
||||
def test_getattr_with_default(self):
|
||||
class A(torch.nn.Module):
|
||||
@ -59,7 +66,3 @@ class TestGetDefaultAttr(JitTestCase):
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "but got a normal Tuple"):
|
||||
torch.jit.script(fn)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -4,10 +4,7 @@
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import (
|
||||
raise_on_run_directly,
|
||||
skipIfTorchDynamo,
|
||||
)
|
||||
from torch.testing._internal.common_utils import skipIfTorchDynamo
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
@ -148,7 +145,3 @@ class TestAutodiffJit(JitTestCase):
|
||||
self.assertEqual(x_s.requires_grad, x.requires_grad)
|
||||
self.assertEqual(y_s.requires_grad, y.requires_grad)
|
||||
self.assertEqual(z_s.requires_grad, z.requires_grad)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -20,13 +20,20 @@ sys.path.append(pytorch_test_dir)
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import (
|
||||
disable_autodiff_subgraph_inlining,
|
||||
JitTestCase,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients"
|
||||
)
|
||||
@ -582,7 +589,3 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
FileCheck().check("= prim::DifferentiableGraph").check(
|
||||
"with prim::DifferentiableGraph"
|
||||
).check(" = aten::relu").check("requires_grad=0").check("aten::relu").run(graph)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -6,7 +6,6 @@ from typing import List, Optional, Tuple
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch._awaits import _Await as Await
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
|
||||
@ -391,7 +390,3 @@ class TestAwait(JitTestCase):
|
||||
sm = torch.jit.load(iofile)
|
||||
script_out_load = sm(inp)
|
||||
self.assertTrue(torch.allclose(expected, script_out_load))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -7,11 +7,7 @@ from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch._C
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_FBCODE,
|
||||
raise_on_run_directly,
|
||||
skipIfTorchDynamo,
|
||||
)
|
||||
from torch.testing._internal.common_utils import IS_FBCODE, skipIfTorchDynamo
|
||||
|
||||
|
||||
# hacky way to skip these tests in fbcode:
|
||||
@ -32,6 +28,13 @@ else:
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
"""
|
||||
Unit Tests for Nnapi backend with delegate
|
||||
Inherits most tests from TestNNAPI, which loads Android NNAPI models
|
||||
@ -136,7 +139,3 @@ method_compile_spec must use the following format:
|
||||
def tearDown(self):
|
||||
# Change dtype back to default (Otherwise, other unit tests will complain)
|
||||
torch.set_default_dtype(self.default_dtype)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -15,7 +15,6 @@ from torch.testing._internal.common_utils import (
|
||||
IS_MACOS,
|
||||
IS_SANDCASTLE,
|
||||
IS_WINDOWS,
|
||||
raise_on_run_directly,
|
||||
skipIfRocm,
|
||||
TEST_WITH_ROCM,
|
||||
)
|
||||
@ -26,6 +25,13 @@ from torch.testing._internal.jit_utils import JitTestCase
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
def to_test_backend(module, method_compile_spec):
|
||||
return torch._C._jit_to_backend(
|
||||
@ -816,7 +822,3 @@ class AddedAttributesTest(JitBackendTestCase):
|
||||
)
|
||||
self.assertEqual(pre_bundled, post_bundled)
|
||||
self.assertEqual(post_bundled, post_load)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -2,19 +2,24 @@
|
||||
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestBatchMM(JitTestCase):
|
||||
@staticmethod
|
||||
def _get_test_tensors(n: int):
|
||||
return [
|
||||
(
|
||||
torch.tensor([[1 + x, 2 + x, 3 + x], [4 + x, 5 + x, 6 + x]])
|
||||
if x % 2 == 0
|
||||
else torch.tensor([[1 + x, 2 + x], [3 + x, 4 + x], [5 + x, 6 + x]])
|
||||
)
|
||||
torch.tensor([[1 + x, 2 + x, 3 + x], [4 + x, 5 + x, 6 + x]])
|
||||
if x % 2 == 0
|
||||
else torch.tensor([[1 + x, 2 + x], [3 + x, 4 + x], [5 + x, 6 + x]])
|
||||
for x in range(n)
|
||||
]
|
||||
|
||||
@ -283,7 +288,3 @@ class TestBatchMM(JitTestCase):
|
||||
FileCheck().check_count("aten::mm", 10, exactly=True).check_not(
|
||||
"prim::MMBatchSide"
|
||||
).run(test_batch_mm.graph)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -13,10 +13,17 @@ from torch.testing import FileCheck
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestBuiltins(JitTestCase):
|
||||
"""
|
||||
Tests for TorchScript support of Python builtin functions.
|
||||
@ -292,7 +299,3 @@ class TestTensorBuiltins(JitTestCase):
|
||||
self.assertEqual(
|
||||
test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -18,14 +18,18 @@ sys.path.append(pytorch_test_dir)
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import torch.testing._internal.jit_utils
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_SANDCASTLE,
|
||||
raise_on_run_directly,
|
||||
skipIfTorchDynamo,
|
||||
)
|
||||
from torch.testing._internal.common_utils import IS_SANDCASTLE, skipIfTorchDynamo
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestClassType(JitTestCase):
|
||||
def test_reference_semantics(self):
|
||||
"""
|
||||
@ -1663,7 +1667,3 @@ class TestClassType(JitTestCase):
|
||||
for fn in (fn_a, fn_b, fn_c, fn_d, fn_e):
|
||||
with self.assertRaisesRegex(RuntimeError, error_message_regex):
|
||||
torch.jit.script(fn)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -8,7 +8,7 @@ from textwrap import dedent
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import IS_MACOS, raise_on_run_directly
|
||||
from torch.testing._internal.common_utils import IS_MACOS
|
||||
from torch.testing._internal.jit_utils import execWrapper, JitTestCase
|
||||
|
||||
|
||||
@ -617,7 +617,3 @@ class TestComplex(JitTestCase):
|
||||
scripted = torch.jit.script(op)
|
||||
jit_result = scripted(x, y)
|
||||
self.assertEqual(eager_result, jit_result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -13,6 +13,7 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_FBCODE,
|
||||
run_tests,
|
||||
set_default_dtype,
|
||||
suppress_warnings,
|
||||
)
|
||||
@ -104,7 +105,4 @@ class TestComplexity(JitTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test is not currently used and should be "
|
||||
"enabled in discover_tests.py if required."
|
||||
)
|
||||
run_tests()
|
||||
|
@ -22,10 +22,16 @@ skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
activations = [
|
||||
F.celu,
|
||||
F.elu,
|
||||
@ -198,7 +204,3 @@ class TestInplaceToFunctionalActivation(JitTestCase):
|
||||
inp = torch.randn(N, C, H, W)
|
||||
self.run_pass("inplace_to_functional_activation", frozen_model.graph)
|
||||
self.assertEqual(model(inp), frozen_model(inp))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -12,7 +12,6 @@ from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_cuda import TEST_MULTIGPU
|
||||
from torch.testing._internal.common_utils import (
|
||||
NoTest,
|
||||
raise_on_run_directly,
|
||||
skipCUDANonDefaultStreamIf,
|
||||
skipIfRocm,
|
||||
TEST_CUDA,
|
||||
@ -37,6 +36,13 @@ if TEST_CUDA:
|
||||
torch.ones(1).cuda() # initialize cuda context
|
||||
TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 5e9
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestCUDA(JitTestCase):
|
||||
"""
|
||||
@ -692,7 +698,3 @@ class TestCUDA(JitTestCase):
|
||||
FileCheck().check("cuda::_maybe_exchange_device(").run(g)
|
||||
torch._C._jit_pass_inline(g)
|
||||
FileCheck().check("cuda::_maybe_exchange_device(").run(g)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -10,10 +10,17 @@ import torch
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
def canonical(graph):
|
||||
return torch._C._jit_pass_canonicalize(graph).str(False)
|
||||
|
||||
@ -144,7 +151,3 @@ graph(%x.1 : Tensor):
|
||||
def test_where_no_scalar(self):
|
||||
x = torch.rand(1, 3, 224, 224)
|
||||
torch.ops.aten.where(x > 0.5, -1.5, 1.5) # does not raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -12,10 +12,17 @@ import torch.nn.parallel as dp
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA_MULTI_GPU
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestDataParallel(JitTestCase):
|
||||
class Mpy(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -151,7 +158,3 @@ class TestDataParallel(JitTestCase):
|
||||
x1 = torch.ones(2, 2, requires_grad=True).cuda(device=1)
|
||||
r1_forward = replica[1](x1)
|
||||
self.assertEqual(first_forward, r1_forward)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -7,7 +7,6 @@ from typing import List, Optional
|
||||
from hypothesis import given, settings, strategies as st
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
@ -169,7 +168,3 @@ class TestDataclasses(JitTestCase):
|
||||
|
||||
with self.assertRaises(OSError):
|
||||
torch.jit.script(f)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
|
||||
@ -74,7 +73,3 @@ class TestDCE(JitTestCase):
|
||||
torch._C._jit_pass_dce_graph(fn_s.graph)
|
||||
|
||||
FileCheck().check("aten::add_").run(fn_s.graph)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -19,10 +19,3 @@ class TestDecorator(JitTestCase):
|
||||
fn = my_function_a
|
||||
fx = torch.jit.script(fn)
|
||||
self.assertEqual(fn(1.0), fx(1.0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test is not currently used and should be "
|
||||
"enabled in discover_tests.py if required."
|
||||
)
|
||||
|
@ -5,7 +5,7 @@ from itertools import product
|
||||
|
||||
import torch
|
||||
from torch.jit._passes._property_propagation import apply_input_props_using_example
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly, TEST_CUDA
|
||||
from torch.testing._internal.common_utils import TEST_CUDA
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
@ -14,6 +14,13 @@ try:
|
||||
except ImportError:
|
||||
models = None
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestDeviceAnalysis(JitTestCase):
|
||||
@classmethod
|
||||
@ -329,7 +336,3 @@ class TestDeviceAnalysis(JitTestCase):
|
||||
test_fn, [self.mkldnn, self.mkldnn, None, None], self.mkldnn
|
||||
)
|
||||
self.assert_device_equal(test_fn, [self.cpu, self.cuda, None, None], None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -17,11 +17,7 @@ from torch.testing._internal.common_methods_invocations import (
|
||||
sample_inputs_conv2d,
|
||||
SampleInput,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
first_sample,
|
||||
raise_on_run_directly,
|
||||
set_default_dtype,
|
||||
)
|
||||
from torch.testing._internal.common_utils import first_sample, set_default_dtype
|
||||
from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
@ -31,6 +27,14 @@ Dtype Analysis relies on symbolic shape analysis, which is still in beta
|
||||
"""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
custom_rules_works_list = {
|
||||
"nn.functional.adaptive_avg_pool1d",
|
||||
"nn.functional.adaptive_avg_pool2d",
|
||||
@ -382,6 +386,3 @@ class TestDtypeCustomRules(TestDtypeBase):
|
||||
TestDtypeCustomRulesCPU = None
|
||||
# This creates TestDtypeCustomRulesCPU
|
||||
instantiate_device_type_tests(TestDtypeCustomRules, globals(), only_for=("cpu",))
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -12,10 +12,17 @@ from torch.testing import FileCheck
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestEnum(JitTestCase):
|
||||
def test_enum_value_types(self):
|
||||
class IntEnum(Enum):
|
||||
@ -351,7 +358,3 @@ class TestEnum(JitTestCase):
|
||||
@torch.jit.script
|
||||
def is_red(x: Color) -> bool:
|
||||
return x == Color.RED
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -197,10 +197,3 @@ class TestException(TestCase):
|
||||
"jit.myexception.MyKeyError: This is a user defined key error",
|
||||
):
|
||||
fn()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test is not currently used and should be "
|
||||
"enabled in discover_tests.py if required."
|
||||
)
|
||||
|
@ -15,7 +15,6 @@ from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN
|
||||
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
|
||||
from torch.testing._internal.common_quantized import override_quantized_engine
|
||||
from torch.testing._internal.common_utils import (
|
||||
raise_on_run_directly,
|
||||
set_default_dtype,
|
||||
skipCUDAMemoryLeakCheckIf,
|
||||
skipIfTorchDynamo,
|
||||
@ -33,6 +32,13 @@ except ImportError:
|
||||
HAS_TORCHVISION = False
|
||||
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
TEST_ROCM = torch.cuda.is_available() and torch.version.hip is not None
|
||||
|
||||
|
||||
@ -49,7 +55,7 @@ class TestFreezing(JitTestCase):
|
||||
self.a = 1 # folded
|
||||
self.b = 1.2 # folded
|
||||
self.c = "hello" # folded
|
||||
self.c2 = "hi\xa1" # not folded
|
||||
self.c2 = "hi\xA1" # not folded
|
||||
self.d = [1, 1] # folded
|
||||
self.e = [1.0, 1.1] # folded
|
||||
self.f = ["hello", "world"] # folded
|
||||
@ -61,7 +67,7 @@ class TestFreezing(JitTestCase):
|
||||
torch.tensor([5.5], requires_grad=True),
|
||||
) # folded
|
||||
self.h = {"layer": [torch.tensor([7.7], requires_grad=True)]}
|
||||
self.h2 = {"layer\xb1": [torch.tensor([8.8], requires_grad=True)]}
|
||||
self.h2 = {"layer\xB1": [torch.tensor([8.8], requires_grad=True)]}
|
||||
self.t = torch.tensor([1.2, 2.4], requires_grad=True) # folded
|
||||
self.ts = [
|
||||
torch.tensor([1.0, 2.0], requires_grad=True),
|
||||
@ -3455,7 +3461,3 @@ class TestMKLDNNReinplacing(JitTestCase):
|
||||
mod = self.freezeAndConvert(mod_eager)
|
||||
FileCheck().check("aten::add_").run(mod.graph)
|
||||
self.checkResults(mod_eager, mod)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -10,10 +10,17 @@ from torch.testing import FileCheck
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestFunctionalBlocks(JitTestCase):
|
||||
def test_subgraph_creation(self):
|
||||
def fn(x, y, z):
|
||||
@ -47,7 +54,3 @@ class TestFunctionalBlocks(JitTestCase):
|
||||
FileCheck().check("add").check("add_").check_not("mul").check(
|
||||
"FunctionalGraph"
|
||||
).run(graph)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -1,7 +1,6 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
@ -20,7 +19,3 @@ class TestFuserCommon(JitTestCase):
|
||||
# test fallback when optimization is not applicable
|
||||
y = fn(torch.randn(5, requires_grad=rq))
|
||||
self.assertEqual(y.requires_grad, rq)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit_fuser_te.py")
|
||||
|
@ -6,13 +6,18 @@ import unittest
|
||||
|
||||
import torch
|
||||
from torch.nn import init
|
||||
from torch.testing._internal.common_utils import (
|
||||
raise_on_run_directly,
|
||||
skipIfLegacyJitExecutor,
|
||||
)
|
||||
from torch.testing._internal.common_utils import skipIfLegacyJitExecutor
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestGenerator(JitTestCase):
|
||||
# torch.jit.trace does not properly capture the generator manual seed
|
||||
# and thus is non deterministic even if the generator is manually seeded
|
||||
@ -188,7 +193,3 @@ class TestGenerator(JitTestCase):
|
||||
except: # noqa: B001, E722
|
||||
print(loaded_module.forward.code)
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -3,7 +3,6 @@
|
||||
import torch
|
||||
import torch._C
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
@ -60,7 +59,3 @@ class TestGraphRewritePasses(JitTestCase):
|
||||
FileCheck().check_not("aten::linear").run(model.graph)
|
||||
# make sure it runs
|
||||
model(x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -10,10 +10,17 @@ import torch
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestHash(JitTestCase):
|
||||
def test_hash_tuple(self):
|
||||
def fn(t1: Tuple[int, int], t2: Tuple[int, int]) -> bool:
|
||||
@ -108,7 +115,3 @@ class TestHash(JitTestCase):
|
||||
self.checkScript(fn, (gpu0, gpu1))
|
||||
self.checkScript(fn, (gpu0, cpu))
|
||||
self.checkScript(fn, (cpu, cpu))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -33,10 +33,17 @@ from jit.test_hooks_modules import (
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
# Tests for JIT forward hooks and pre-hooks
|
||||
class TestHooks(JitTestCase):
|
||||
def test_module_no_forward_input(self):
|
||||
@ -386,7 +393,3 @@ class TestHooks(JitTestCase):
|
||||
r"Received type: 'str'. Expected type: 'Tuple\[str\]'",
|
||||
):
|
||||
torch.jit.script(m)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -528,9 +528,3 @@ def create_submodule_forward_single_input_return_not_tupled():
|
||||
m.submodule.register_forward_hook(forward_hook)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This file is a collection of utils, it should be imported not executed directly"
|
||||
)
|
||||
|
@ -11,10 +11,17 @@ from torch.testing import FileCheck
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
# Tests that Python slice class is supported in TorchScript
|
||||
class TestIgnorableArgs(JitTestCase):
|
||||
def test_slice_ignorable_args_for_slice(self):
|
||||
@ -54,7 +61,3 @@ class TestIgnorableArgs(JitTestCase):
|
||||
torch.add(x, y, out=y)
|
||||
|
||||
FileCheck().check("torch.add(x, y, out=y)").run(fn.code)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -11,10 +11,17 @@ import torch
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.jit.frontend import _IS_ASTUNPARSE_INSTALLED
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestIgnoreContextManager(JitTestCase):
|
||||
@unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
|
||||
def test_with_ignore_context_manager_with_inp_out(self):
|
||||
@ -96,7 +103,3 @@ class TestIgnoreContextManager(JitTestCase):
|
||||
s = torch.jit.script(model)
|
||||
self.assertEqual(s(), 5)
|
||||
self.assertEqual(s(), model())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -11,10 +11,17 @@ import torch
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
# Tests for torch.jit.isinstance
|
||||
class TestIsinstance(JitTestCase):
|
||||
def test_int(self):
|
||||
@ -347,7 +354,3 @@ class TestIsinstance(JitTestCase):
|
||||
# Should not throw "Boolean value of Tensor with more than
|
||||
# one value is ambiguous" error
|
||||
torch._jit_internal.check_empty_containers(torch.rand(2, 3))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -11,10 +11,17 @@ from torch.testing._internal import jit_utils
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
# Tests various JIT-related utility functions.
|
||||
class TestJitUtils(JitTestCase):
|
||||
# Tests that POSITIONAL_OR_KEYWORD arguments are captured.
|
||||
@ -109,7 +116,3 @@ class TestJitUtils(JitTestCase):
|
||||
with jit_utils.NoTracerWarnContextManager():
|
||||
self.assertEqual(False, torch._C._jit_get_tracer_state_warn())
|
||||
self.assertEqual(True, torch._C._jit_get_tracer_state_warn())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -19,14 +19,18 @@ from torch.testing import FileCheck
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import (
|
||||
raise_on_run_directly,
|
||||
skipIfTorchDynamo,
|
||||
TEST_CUDA,
|
||||
)
|
||||
from torch.testing._internal.common_utils import skipIfTorchDynamo, TEST_CUDA
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestList(JitTestCase):
|
||||
def test_list_bool_conversion(self):
|
||||
def if_predicate(l: List[int]):
|
||||
@ -1821,7 +1825,7 @@ class TestDict(JitTestCase):
|
||||
def test_popitem(self):
|
||||
@torch.jit.script
|
||||
def popitem(
|
||||
x: Dict[str, Tensor],
|
||||
x: Dict[str, Tensor]
|
||||
) -> Tuple[Tuple[str, Tensor], Dict[str, Tensor]]:
|
||||
item = x.popitem()
|
||||
return item, x
|
||||
@ -2992,7 +2996,3 @@ class TestScriptList(JitTestCase):
|
||||
for i in range(300):
|
||||
test = Test()
|
||||
test_script = torch.jit.script(test)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -10,10 +10,17 @@ import torch
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestLogging(JitTestCase):
|
||||
def test_bump_numeric_counter(self):
|
||||
class ModuleThatLogs(torch.jit.ScriptModule):
|
||||
@ -115,7 +122,3 @@ class TestLogging(JitTestCase):
|
||||
def test_logging_levels_set(self):
|
||||
torch._C._jit_set_logging_option("foo")
|
||||
self.assertEqual("foo", torch._C._jit_get_logging_option())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -12,7 +12,7 @@ import torch.testing._internal.jit_utils
|
||||
from jit.test_module_interface import TestModuleInterface # noqa: F401
|
||||
from torch import jit
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import freeze_rng_state, raise_on_run_directly
|
||||
from torch.testing._internal.common_utils import freeze_rng_state
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global, RUN_CUDA_HALF
|
||||
|
||||
|
||||
@ -20,6 +20,13 @@ from torch.testing._internal.jit_utils import JitTestCase, make_global, RUN_CUDA
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestMisc(JitTestCase):
|
||||
def test_joined_str(self):
|
||||
@ -122,7 +129,7 @@ class TestMisc(JitTestCase):
|
||||
def test_subexpression_Tuple_int_int_Future(self):
|
||||
@torch.jit.script
|
||||
def fn(
|
||||
x: Tuple[int, int, torch.jit.Future[int]],
|
||||
x: Tuple[int, int, torch.jit.Future[int]]
|
||||
) -> Tuple[int, torch.jit.Future[int]]:
|
||||
return x[0], x[2]
|
||||
|
||||
@ -140,7 +147,7 @@ class TestMisc(JitTestCase):
|
||||
def test_subexpression_Optional(self):
|
||||
@torch.jit.script
|
||||
def fn(
|
||||
x: Optional[Dict[int, torch.jit.Future[int]]],
|
||||
x: Optional[Dict[int, torch.jit.Future[int]]]
|
||||
) -> Optional[torch.jit.Future[int]]:
|
||||
if x is not None:
|
||||
return x[0]
|
||||
@ -497,7 +504,3 @@ class TestMisc(JitTestCase):
|
||||
self.assertTrue(len(complex_indices) > 0)
|
||||
self.assertTrue(len(Scalar_indices) > 0)
|
||||
self.assertTrue(complex_indices[0] > Scalar_indices[0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -11,19 +11,24 @@ from torch.testing._internal.common_utils import (
|
||||
enable_profiling_mode_for_profiling_tests,
|
||||
GRAPH_EXECUTOR,
|
||||
ProfilingMode,
|
||||
raise_on_run_directly,
|
||||
set_default_dtype,
|
||||
slowTest,
|
||||
suppress_warnings,
|
||||
)
|
||||
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import slowTest, suppress_warnings
|
||||
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
try:
|
||||
import torchvision
|
||||
|
||||
@ -79,7 +84,7 @@ class TestModels(JitTestCase):
|
||||
nn.ReLU(True),
|
||||
# state size. (ngf) x 32 x 32
|
||||
nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
|
||||
nn.Tanh(),
|
||||
nn.Tanh()
|
||||
# state size. (nc) x 64 x 64
|
||||
)
|
||||
|
||||
@ -749,7 +754,3 @@ class TestModels(JitTestCase):
|
||||
m = self.createFunctionFromGraph(g)
|
||||
with torch.random.fork_rng(devices=[]):
|
||||
self.assertEqual(outputs, m(*inputs))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -5,7 +5,6 @@ import sys
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
@ -13,6 +12,13 @@ from torch.testing._internal.jit_utils import JitTestCase
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestModuleAPIs(JitTestCase):
|
||||
def test_default_state_dict_methods(self):
|
||||
@ -135,7 +141,3 @@ class TestModuleAPIs(JitTestCase):
|
||||
self.assertFalse(m2.sub.customized_load_state_dict_called)
|
||||
m2.load_state_dict(state_dict)
|
||||
self.assertTrue(m2.sub.customized_load_state_dict_called)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -7,7 +7,6 @@ from typing import Any, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
@ -15,6 +14,13 @@ from torch.testing._internal.jit_utils import JitTestCase
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestModuleContainers(JitTestCase):
|
||||
def test_sequential_intermediary_types(self):
|
||||
@ -750,7 +756,3 @@ class TestModuleContainers(JitTestCase):
|
||||
)
|
||||
|
||||
self.checkModule(MyModule(), (torch.ones(1),))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -8,7 +8,6 @@ from typing import Any, List
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
|
||||
@ -16,6 +15,13 @@ from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class OrigModule(nn.Module):
|
||||
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
|
||||
@ -695,7 +701,3 @@ class TestModuleInterface(JitTestCase):
|
||||
|
||||
with self.assertRaisesRegex(Exception, "Could not compile"):
|
||||
scripted_mod = torch.jit.script(TestModule())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -4,7 +4,6 @@ import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
@ -12,6 +11,13 @@ from torch.testing._internal.jit_utils import JitTestCase
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestModules(JitTestCase):
|
||||
def test_script_module_with_constants_list(self):
|
||||
@ -30,7 +36,3 @@ class TestModules(JitTestCase):
|
||||
self.x = 0
|
||||
|
||||
self.checkModule(Net(), (torch.randn(5),))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -2,10 +2,17 @@
|
||||
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestOpDecompositions(JitTestCase):
|
||||
def test_op_decomposition(self):
|
||||
def foo(x):
|
||||
@ -35,7 +42,3 @@ class TestOpDecompositions(JitTestCase):
|
||||
FileCheck().check_not("aten::square").check("aten::pow").run(foo.graph)
|
||||
x = torch.rand([4])
|
||||
self.assertEqual(foo(x), torch.square(x))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -3,7 +3,7 @@
|
||||
import torch
|
||||
import torch._C
|
||||
import torch.nn.functional as F
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly, skipIfNoXNNPACK
|
||||
from torch.testing._internal.common_utils import skipIfNoXNNPACK
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
@ -263,7 +263,3 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
|
||||
conv2d_activation=F.relu,
|
||||
conv2d_activation_kind="aten::relu",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -4,10 +4,17 @@
|
||||
import torch
|
||||
import torch.nn.utils.parametrize as parametrize
|
||||
from torch import nn
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestParametrization(JitTestCase):
|
||||
# Define some parametrization
|
||||
class Symmetric(nn.Module):
|
||||
@ -61,7 +68,3 @@ class TestParametrization(JitTestCase):
|
||||
# Check the scripting process throws an error when caching
|
||||
with self.assertRaisesRegex(RuntimeError, "Caching is not implemented"):
|
||||
scripted_model = torch.jit.trace_module(model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -6,7 +6,7 @@ from typing import Any, Dict, List, NamedTuple, Optional, Tuple # noqa: F401
|
||||
|
||||
import torch
|
||||
from torch.jit._monkeytype_config import _IS_MONKEYTYPE_INSTALLED
|
||||
from torch.testing._internal.common_utils import NoTest, raise_on_run_directly
|
||||
from torch.testing._internal.common_utils import NoTest
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
|
||||
@ -21,6 +21,13 @@ if not _IS_MONKEYTYPE_INSTALLED:
|
||||
)
|
||||
JitTestCase = NoTest # type: ignore[misc, assignment] # noqa: F811
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestPDT(JitTestCase):
|
||||
"""
|
||||
@ -889,7 +896,3 @@ class TestPDT(JitTestCase):
|
||||
torch.ones(1),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -6,10 +6,17 @@ from typing import Callable, List
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import _inline_everything, JitTestCase, RUN_CUDA
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestPeephole(JitTestCase):
|
||||
def test_peephole_with_writes(self):
|
||||
def test_write(x):
|
||||
@ -883,7 +890,3 @@ class TestPeephole(JitTestCase):
|
||||
|
||||
self.run_pass("peephole", foo.graph)
|
||||
FileCheck().check("aten::slice").run(foo.graph)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -4,10 +4,7 @@ import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import (
|
||||
raise_on_run_directly,
|
||||
skipIfTorchDynamo,
|
||||
)
|
||||
from torch.testing._internal.common_utils import skipIfTorchDynamo
|
||||
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
@ -16,6 +13,14 @@ sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import FileCheck, JitTestCase, warmup_backward
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
@skipIfTorchDynamo()
|
||||
class TestProfiler(JitTestCase):
|
||||
def setUp(self):
|
||||
@ -279,7 +284,3 @@ class TestProfiler(JitTestCase):
|
||||
|
||||
g = torch.jit.last_executed_optimized_graph()
|
||||
self.assertEqual(len(list(g.findAllNodes("prim::TensorExprGroup"))), 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -2,10 +2,17 @@
|
||||
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TestPythonBindings\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestPythonBindings(JitTestCase):
|
||||
def test_cu_get_functions(self):
|
||||
@torch.jit.script
|
||||
@ -107,7 +114,3 @@ graph(%p207 : Tensor,
|
||||
graph3 = torch._C.parse_ir(ir)
|
||||
graph3 = torch._C._jit_pass_canonicalize(graph3, False)
|
||||
FileCheck().check_not("%p207").run(graph3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -7,7 +7,6 @@ import tempfile
|
||||
from textwrap import dedent
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import execWrapper, JitTestCase
|
||||
|
||||
|
||||
@ -15,6 +14,13 @@ from torch.testing._internal.jit_utils import execWrapper, JitTestCase
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
def get_fn(file_name, script_path):
|
||||
import importlib.util
|
||||
@ -467,7 +473,3 @@ class TestPythonBuiltinOP(JitTestCase):
|
||||
|
||||
s = torch.rand(1)
|
||||
self.assertTrue(foo(s))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -6,10 +6,18 @@ import numpy as np
|
||||
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import IS_MACOS, raise_on_run_directly
|
||||
from torch.testing._internal.common_utils import IS_MACOS
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestPythonIr(JitTestCase):
|
||||
def test_param_strides(self):
|
||||
def trace_me(arg):
|
||||
@ -92,7 +100,3 @@ class TestPythonIr(JitTestCase):
|
||||
|
||||
FileCheck().check_not("aten::mul").check("aten::add").run(foo.graph)
|
||||
self.assertEqual(foo(torch.ones([2, 2])), torch.ones([2, 2]) * 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -20,13 +20,20 @@ from torch.testing import FileCheck
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import (
|
||||
_tmp_donotuse_dont_inline_everything,
|
||||
JitTestCase,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestRecursiveScript(JitTestCase):
|
||||
def test_inferred_nonetype(self):
|
||||
class M(nn.Module):
|
||||
@ -792,7 +799,3 @@ class TestRecursiveScript(JitTestCase):
|
||||
# ScriptModule should correctly reflect the override.
|
||||
s = torch.jit.script(m)
|
||||
self.assertEqual(s.i_am_ignored(), "new")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -11,10 +11,17 @@ from torch.testing import FileCheck
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import freeze_rng_state, JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestRemoveMutation(JitTestCase):
|
||||
def test_aten_inplace(self):
|
||||
def test_not_new_alias(x):
|
||||
@ -311,7 +318,3 @@ class TestRemoveMutation(JitTestCase):
|
||||
|
||||
self.run_pass("remove_mutation", mod_script.forward.graph)
|
||||
FileCheck().check("aten::add_").run(test_multiple_uses.graph)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -8,11 +8,7 @@ from typing import NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.testing._internal.common_utils import (
|
||||
raise_on_run_directly,
|
||||
skipIfTorchDynamo,
|
||||
TemporaryFileName,
|
||||
)
|
||||
from torch.testing._internal.common_utils import skipIfTorchDynamo, TemporaryFileName
|
||||
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
@ -21,6 +17,14 @@ sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import clear_class_registry, JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestSaveLoad(JitTestCase):
|
||||
def test_different_modules(self):
|
||||
"""
|
||||
@ -1193,7 +1197,3 @@ class TestSaveLoadFlatbuffer(JitTestCase):
|
||||
torch._C._get_model_extra_files_from_buffer(script_module_io, re_extra_files)
|
||||
|
||||
self.assertEqual(extra_files, re_extra_files)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -17,10 +17,17 @@ import torch
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.jit.mobile import _load_for_lite_interpreter
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestSaveLoadForOpVersion(JitTestCase):
|
||||
# Helper that returns the module after saving and loading
|
||||
def _save_load_module(self, m):
|
||||
@ -610,7 +617,3 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
self.assertTrue(output.size(dim=0) == 100)
|
||||
# "Upgraded" model should match the new version output
|
||||
self.assertEqual(output, output_current)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -10,10 +10,17 @@ from torch import nn
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class Sequence(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
@ -108,7 +115,3 @@ class TestScriptProfile(JitTestCase):
|
||||
p.enable()
|
||||
p.disable()
|
||||
self.assertEqual(p.dump_string(), "")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -11,10 +11,17 @@ import torch
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||
# NB: There are no tests for `Tuple` or `NamedTuple` here. In fact,
|
||||
# reassigning a non-empty Tuple to an attribute previously typed
|
||||
@ -356,7 +363,3 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||
"empty non-base types",
|
||||
):
|
||||
torch.jit.script(M())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -10,10 +10,17 @@ import torch
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
# Tests that Python slice class is supported in TorchScript
|
||||
class TestSlice(JitTestCase):
|
||||
def test_slice_kwarg(self):
|
||||
@ -171,7 +178,3 @@ class TestSlice(JitTestCase):
|
||||
self.assertEqual(result2[0].identifier, "B")
|
||||
self.assertEqual(result2[1].identifier, "C")
|
||||
self.assertEqual(result2[2].identifier, "D")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -4,11 +4,7 @@ import io
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_WINDOWS,
|
||||
raise_on_run_directly,
|
||||
TEST_MKL,
|
||||
)
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS, TEST_MKL
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
@ -122,7 +118,3 @@ class TestSparse(JitTestCase):
|
||||
loaded_result = loaded_model.forward(x)
|
||||
|
||||
self.assertEqual(expected_result.to_dense(), loaded_result.to_dense())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -10,10 +10,17 @@ import torch
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestStringFormatting(JitTestCase):
|
||||
def test_modulo_operator(self):
|
||||
def fn(dividend: int, divisor: int) -> int:
|
||||
@ -192,7 +199,3 @@ class TestStringFormatting(JitTestCase):
|
||||
'"%a in template" % arg1',
|
||||
):
|
||||
fn("foo")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -9,10 +9,18 @@ import torch
|
||||
from torch import nn, Tensor
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_methods_invocations import sample_inputs_cat_concat
|
||||
from torch.testing._internal.common_utils import make_tensor, raise_on_run_directly
|
||||
from torch.testing._internal.common_utils import make_tensor
|
||||
from torch.testing._internal.jit_utils import execWrapper, JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
# XXX: still in prototype
|
||||
class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
def setUp(self):
|
||||
@ -811,7 +819,3 @@ class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
input.setType(input.type().with_sizes([1, 5, 8]))
|
||||
torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)
|
||||
self.assertEqual(next(foo.graph.outputs()).type().symbolic_sizes(), [5, 8])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -9,10 +9,17 @@ import torch
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestTensorCreationOps(JitTestCase):
|
||||
"""
|
||||
A suite of tests for ops that create tensors.
|
||||
@ -71,7 +78,3 @@ class TestTensorCreationOps(JitTestCase):
|
||||
assert indices.dtype == torch.int32
|
||||
|
||||
self.checkScript(tril_indices, (3, 3))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -10,10 +10,17 @@ import torch
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestTensorMethods(JitTestCase):
|
||||
def test_getitem(self):
|
||||
def tensor_getitem(inp: torch.Tensor):
|
||||
@ -34,7 +41,3 @@ class TestTensorMethods(JitTestCase):
|
||||
RuntimeError, "expected exactly 1 argument", "inp.__getitem__"
|
||||
):
|
||||
torch.jit.script(tensor_getitem_invalid)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -8,10 +8,7 @@ import sys
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import (
|
||||
raise_on_run_directly,
|
||||
skipIfTorchDynamo,
|
||||
)
|
||||
from torch.testing._internal.common_utils import skipIfTorchDynamo
|
||||
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
@ -22,6 +19,14 @@ from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.testing._internal.torchbind_impls import load_torchbind_test_lib
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
@skipIfTorchDynamo("skipping as a precaution")
|
||||
class TestTorchbind(JitTestCase):
|
||||
def setUp(self):
|
||||
@ -458,7 +463,3 @@ class TestTorchbind(JitTestCase):
|
||||
return obj.decrement()
|
||||
|
||||
self.checkScript(gn, ())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -29,7 +29,6 @@ from torch.testing._internal.common_cuda import with_tf32_off
|
||||
from torch.testing._internal.common_utils import (
|
||||
enable_profiling_mode_for_profiling_tests,
|
||||
IS_SANDCASTLE,
|
||||
raise_on_run_directly,
|
||||
skipIfCompiledWithoutNumpy,
|
||||
skipIfCrossRef,
|
||||
skipIfTorchDynamo,
|
||||
@ -47,6 +46,14 @@ from torch.testing._internal.jit_utils import (
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
|
||||
class TestTracer(JitTestCase):
|
||||
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
||||
@ -2819,7 +2826,3 @@ class TestMixTracingScripting(JitTestCase):
|
||||
for n in fn_t.graph.nodes():
|
||||
if n.kind() == "prim::CallFunction":
|
||||
self.assertTrue(n.output().isCompleteTensor())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -10,13 +10,18 @@ import torch
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import (
|
||||
raise_on_run_directly,
|
||||
suppress_warnings,
|
||||
)
|
||||
from torch.testing._internal.common_utils import suppress_warnings
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestTypeSharing(JitTestCase):
|
||||
def assertSameType(self, m1, m2):
|
||||
if not isinstance(m1, torch.jit.ScriptModule):
|
||||
@ -621,7 +626,3 @@ class TestTypeSharing(JitTestCase):
|
||||
# of A, __jit_ignored_attributes__ was modified before scripting s2,
|
||||
# so the set of ignored attributes is different between s1 and s2.
|
||||
self.assertDifferentType(s1, s2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -12,7 +12,6 @@ import torch
|
||||
import torch.testing._internal.jit_utils
|
||||
from jit.test_module_interface import TestModuleInterface # noqa: F401
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
@ -20,6 +19,13 @@ from torch.testing._internal.jit_utils import JitTestCase
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestTypesAndAnnotation(JitTestCase):
|
||||
def test_pep585_type(self):
|
||||
@ -364,7 +370,3 @@ class TestTypesAndAnnotation(JitTestCase):
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "ErrorReason"):
|
||||
t = inferred_type.type()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -7,7 +7,7 @@ from collections import namedtuple
|
||||
from typing import Dict, List, NamedTuple, Tuple
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS, raise_on_run_directly
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
|
||||
@ -15,6 +15,13 @@ from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestTyping(JitTestCase):
|
||||
def test_dict_in_not_in(self):
|
||||
@ -133,7 +140,7 @@ class TestTyping(JitTestCase):
|
||||
|
||||
# Check for invalid key and value type annotation
|
||||
def wrong_key_value_type(
|
||||
dictionary: Dict[torch.jit.ScriptModule, torch.jit.ScriptModule],
|
||||
dictionary: Dict[torch.jit.ScriptModule, torch.jit.ScriptModule]
|
||||
):
|
||||
return
|
||||
|
||||
@ -681,7 +688,3 @@ class TestTyping(JitTestCase):
|
||||
mod2 = LowestModule()
|
||||
mod_s = torch.jit.script(mod)
|
||||
mod2_s = torch.jit.script(mod2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -15,10 +15,17 @@ from torch.testing import FileCheck
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestUnion(JitTestCase):
|
||||
"""
|
||||
This class tests the functionality of `Union`.
|
||||
@ -1059,7 +1066,3 @@ class TestUnion(JitTestCase):
|
||||
# "Union[Dict[str, torch.Tensor], int]",
|
||||
# lhs["dict_comprehension_of_mixed"],
|
||||
# "foobar")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -16,10 +16,17 @@ from torch.testing import FileCheck
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 10), "Requires Python 3.10")
|
||||
class TestUnion(JitTestCase):
|
||||
"""
|
||||
@ -1057,7 +1064,3 @@ class TestUnion(JitTestCase):
|
||||
# "Dict[str, torch.Tensor] | int",
|
||||
# lhs["dict_comprehension_of_mixed"],
|
||||
# "foobar")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -10,10 +10,16 @@ import torch
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
# NOTE: FIXING FAILING TESTS
|
||||
# If you are seeing a test failure from this file, congrats, you improved
|
||||
# parity between JIT and Python API. Before you fix the test, you must also update
|
||||
@ -84,7 +90,3 @@ class TestUnsupportedOps(JitTestCase):
|
||||
func()
|
||||
with self.assertRaisesRegex(Exception, ""):
|
||||
torch.jit.script(func)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -13,10 +13,17 @@ from torch.testing import FileCheck
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestUpgraders(JitTestCase):
|
||||
def _load_model_version(self, loaded_model):
|
||||
buffer = io.BytesIO()
|
||||
@ -339,7 +346,3 @@ class TestUpgraders(JitTestCase):
|
||||
FileCheck().check_count("aten::full", 5).run(loaded_model.graph)
|
||||
version = self._load_model_version(loaded_model)
|
||||
self.assertTrue(version == 5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -13,10 +13,17 @@ from torch.testing import FileCheck
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestWarn(JitTestCase):
|
||||
def test_warn(self):
|
||||
@torch.jit.script
|
||||
@ -141,7 +148,3 @@ class TestWarn(JitTestCase):
|
||||
).run(
|
||||
f.getvalue()
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -6,10 +6,7 @@ import sys
|
||||
from typing import Any, List
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import (
|
||||
raise_on_run_directly,
|
||||
skipIfTorchDynamo,
|
||||
)
|
||||
from torch.testing._internal.common_utils import skipIfTorchDynamo
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
|
||||
@ -17,6 +14,13 @@ from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestWith(JitTestCase):
|
||||
"""
|
||||
@ -643,7 +647,3 @@ class TestWith(JitTestCase):
|
||||
# Nested record function should have child "aten::add"
|
||||
nested_child_events = nested_function_event.cpu_children
|
||||
self.assertTrue("aten::add" in (child.name for child in nested_child_events))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_jit.py")
|
||||
|
@ -184,10 +184,3 @@ class TestXNNPackBackend(unittest.TestCase):
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test is not currently used and should be "
|
||||
"enabled in discover_tests.py if required."
|
||||
)
|
||||
|
@ -871,11 +871,6 @@ def enable_profiling_mode_for_profiling_tests():
|
||||
torch._C._jit_set_profiling_executor(old_prof_exec_state)
|
||||
torch._C._get_graph_executor_optimize(old_prof_mode_state)
|
||||
|
||||
def raise_on_run_directly(file_to_call):
|
||||
raise RuntimeError("This test file is not meant to be run directly, "
|
||||
f"use:\n\n\tpython {file_to_call} TESTNAME\n\n"
|
||||
"instead.")
|
||||
|
||||
@contextmanager
|
||||
def enable_profiling_mode():
|
||||
old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
|
||||
|
Reference in New Issue
Block a user