This PR is part of a series attempting to re-submit https://github.com/pytorch/pytorch/pull/134592 as smaller PRs.

In jit tests:

- Add and use a common raise_on_run_directly helper for the case where a user directly runs a test file that is not meant to be run that way; it prints the file the user should have run instead.
- Raise a RuntimeError on tests which have been disabled (not run).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154725
Approved by: https://github.com/clee2000
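The raise_on_run_directly helper referenced above comes from torch.testing._internal.common_utils and is invoked in the __main__ guard at the bottom of this file. As a minimal sketch of the behavior the PR describes (the body below is an illustrative assumption, not the actual implementation), it amounts to something like:

def raise_on_run_directly(file_to_run: str) -> None:
    # Illustrative assumption: refuse to execute a test file that was run
    # directly and point the user at the intended entry point instead.
    raise RuntimeError(
        "This test file is not meant to be run directly; "
        f"run 'python {file_to_run}' instead."
    )

Here the argument is test/test_jit.py, the file through which these JIT tests are meant to be run.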
# Owner(s): ["oncall: jit"]

import os
import sys
from typing import List

import torch


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase


# Tests that Python slice class is supported in TorchScript
class TestSlice(JitTestCase):
    def test_slice_kwarg(self):
        def slice_kwarg(x: List[int]):
            return x[slice(1, stop=2)]

        with self.assertRaisesRegex(
            RuntimeError, "Slice does not accept any keyword arguments"
        ):
            torch.jit.script(slice_kwarg)

    def test_slice_three_nones(self):
        def three_nones(x: List[int]):
            return x[slice(None, None, None)]

        self.checkScript(three_nones, (range(10),))

    def test_slice_two_nones(self):
        def two_nones(x: List[int]):
            return x[slice(None, None)]

        self.checkScript(two_nones, (range(10),))

    def test_slice_one_none(self):
        def one_none(x: List[int]):
            return x[slice(None)]

        self.checkScript(one_none, (range(10),))

    def test_slice_stop_only(self):
        def fn(x: List[int]):
            return x[slice(5)]

        self.checkScript(fn, (range(10),))

    def test_slice_stop_only_with_nones(self):
        def fn(x: List[int]):
            return x[slice(None, 5, None)]

        self.checkScript(fn, (range(10),))

    def test_slice_start_stop(self):
        def fn(x: List[int]):
            return x[slice(1, 5)]

        self.checkScript(fn, (range(10),))

    def test_slice_start_stop_with_none(self):
        def fn(x: List[int]):
            return x[slice(1, 5, None)]

        self.checkScript(fn, (range(10),))

    def test_slice_start_stop_step(self):
        def fn(x: List[int]):
            return x[slice(0, 6, 2)]

        self.checkScript(fn, (range(10),))

    def test_slice_string(self):
        def fn(x: str):
            return x[slice(None, 3, 1)]

        self.checkScript(fn, ("foo_bar",))

    def test_slice_tensor(self):
        def fn(x: torch.Tensor):
            return x[slice(None, 3, 1)]

        self.checkScript(fn, (torch.ones(10),))

    def test_slice_tensor_multidim(self):
        def fn(x: torch.Tensor):
            return x[slice(None, 3, 1), 0]

        self.checkScript(fn, (torch.ones((10, 10)),))

    def test_slice_tensor_multidim_with_dots(self):
        def fn(x: torch.Tensor):
            return x[slice(None, 3, 1), ...]

        self.checkScript(fn, (torch.ones((10, 10)),))

    def test_slice_as_variable(self):
        def fn(x: List[int]):
            a = slice(1)
            return x[a]

        self.checkScript(fn, (range(10),))

    def test_slice_stop_clipped(self):
        def fn(x: List[int]):
            return x[slice(1000)]

        self.checkScript(fn, (range(10),))

    def test_slice_dynamic_index(self):
        def t(x):
            slice1 = x[0:1]
            zero = 0
            one = zero + 1
            slice2 = x[zero:one]
            return slice1 + slice2

        self.checkScript(t, (torch.zeros(3, 2, 3),))

    def test_tuple_slicing(self):
        def tuple_slice(a):
            if bool(a):
                b = (1, 2, 3, 4)
            else:
                b = (4, 3, 2, 1)
            c = b[-4:4]
            e = c[1:-1]
            return e

        self.checkScript(tuple_slice, (torch.tensor([1]),), optimize=True)
        scripted_fn = torch.jit.script(tuple_slice)
        self.assertEqual(scripted_fn(torch.tensor(1)), (2, 3))
        tuple_graph = scripted_fn.graph
        slices = tuple_graph.findAllNodes("prim::TupleConstruct")
        num_outputs = {len(x.output().type().elements()) for x in slices}
        # there should be only one tupleSlice with length of 2
        self.assertTrue(num_outputs == {2})
        self.run_pass("lower_all_tuples", tuple_graph)
        self.assertTrue("Tuple" not in str(tuple_graph))

    def test_module_list_slicing(self):
        class Bar(torch.nn.Module):
            def __init__(self, identifier: str):
                super().__init__()
                self.identifier = identifier

            def forward(self):
                return 0

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                module_list = [Bar("A"), Bar("B"), Bar("C"), Bar("D"), Bar("E")]
                self.test = torch.nn.ModuleList(module_list)

            def forward(self):
                return self.test[::-2], self.test[1:4:]

        scripted_foo = torch.jit.script(Foo())
        result1, result2 = scripted_foo()

        self.assertEqual(len(result1), 3)
        self.assertEqual(result1[0].identifier, "E")
        self.assertEqual(result1[1].identifier, "C")
        self.assertEqual(result1[2].identifier, "A")

        self.assertEqual(len(result2), 3)
        self.assertEqual(result2[0].identifier, "B")
        self.assertEqual(result2[1].identifier, "C")
        self.assertEqual(result2[2].identifier, "D")


if __name__ == "__main__":
    raise_on_run_directly("test/test_jit.py")