mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Applies the remaining flake8-comprehension fixes and checks. This change replaces all remaining unnecessary generator expressions with list/dict/set comprehensions, which are more succinct, performant, and better supported by our torch.jit compiler. It also removes useless generators such as `set(a for a in b)`, resolving them into just the set call. Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676 Approved by: https://github.com/ezyang
		
			
				
	
	
		
			172 lines
		
	
	
		
			5.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			172 lines
		
	
	
		
			5.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Owner(s): ["oncall: jit"]
 | |
| 
 | |
| import os
 | |
| import sys
 | |
| 
 | |
| import torch
 | |
| from typing import List
 | |
| 
 | |
# Append the directory two levels above this file's real path to sys.path so
# that the helper files in test/ are importable.
pytorch_test_dir = os.path.dirname(
    os.path.dirname(os.path.realpath(__file__))
)
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase

# Refuse direct execution and point the user at the intended entry point.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
 | |
| 
 | |
| # Tests that Python slice class is supported in TorchScript
 | |
class TestSlice(JitTestCase):
    """TorchScript tests for the builtin ``slice`` object and slice indexing.

    Most cases define a small function and hand it to ``checkScript`` together
    with sample inputs; the remaining cases script code directly and assert on
    the raised error, the returned values, or the resulting graph.
    """

    def test_slice_kwarg(self):
        # Scripting a function that passes a keyword argument to slice() is
        # expected to fail with the error message matched below.
        def slice_kwarg(x: List[int]):
            return x[slice(1, stop=2)]

        with self.assertRaisesRegex(RuntimeError, "Slice does not accept any keyword arguments"):
            torch.jit.script(slice_kwarg)

    def test_slice_three_nones(self):
        # slice(None, None, None): start, stop and step all left unspecified.
        def three_nones(x: List[int]):
            return x[slice(None, None, None)]

        self.checkScript(three_nones, (range(10),))

    def test_slice_two_nones(self):
        # slice(None, None): two-argument form with both bounds unspecified.
        def two_nones(x: List[int]):
            return x[slice(None, None)]

        self.checkScript(two_nones, (range(10),))

    def test_slice_one_none(self):
        # slice(None): single-argument form with the stop unspecified.
        def one_none(x: List[int]):
            return x[slice(None)]

        self.checkScript(one_none, (range(10),))

    def test_slice_stop_only(self):
        # slice(5): single-argument form, which sets only the stop.
        def fn(x: List[int]):
            return x[slice(5)]

        self.checkScript(fn, (range(10),))

    def test_slice_stop_only_with_nones(self):
        # Same stop as above, with start and step spelled out as None.
        def fn(x: List[int]):
            return x[slice(None, 5, None)]

        self.checkScript(fn, (range(10),))

    def test_slice_start_stop(self):
        # slice(1, 5): explicit start and stop, default step.
        def fn(x: List[int]):
            return x[slice(1, 5)]

        self.checkScript(fn, (range(10),))

    def test_slice_start_stop_with_none(self):
        # Explicit start and stop with the step spelled out as None.
        def fn(x: List[int]):
            return x[slice(1, 5, None)]

        self.checkScript(fn, (range(10),))

    def test_slice_start_stop_step(self):
        # slice(0, 6, 2): all three components given explicitly.
        def fn(x: List[int]):
            return x[slice(0, 6, 2)]

        self.checkScript(fn, (range(10),))

    def test_slice_string(self):
        # Indexing a str with a slice object.
        def fn(x: str):
            return x[slice(None, 3, 1)]

        self.checkScript(fn, ("foo_bar",))

    def test_slice_tensor(self):
        # Indexing a 1-D tensor with a slice object.
        def fn(x: torch.Tensor):
            return x[slice(None, 3, 1)]

        self.checkScript(fn, (torch.ones(10),))

    def test_slice_tensor_multidim(self):
        # A slice object combined with an integer index on a 2-D tensor.
        def fn(x: torch.Tensor):
            return x[slice(None, 3, 1), 0]

        self.checkScript(fn, (torch.ones((10, 10)),))

    def test_slice_tensor_multidim_with_dots(self):
        # A slice object combined with an Ellipsis on a 2-D tensor.
        def fn(x: torch.Tensor):
            return x[slice(None, 3, 1), ...]

        self.checkScript(fn, (torch.ones((10, 10)),))

    def test_slice_as_variable(self):
        # The slice object is first bound to a local and then used as an index.
        def fn(x: List[int]):
            a = slice(1)
            return x[a]

        self.checkScript(fn, (range(10),))

    def test_slice_stop_clipped(self):
        # The stop (1000) is larger than the length of the 10-element input.
        def fn(x: List[int]):
            return x[slice(1000)]

        self.checkScript(fn, (range(10),))

    def test_slice_dynamic_index(self):
        # The same [0:1] slice written once with literals and once with
        # bounds held in local variables.
        def t(x):
            slice1 = x[0:1]
            zero = 0
            one = zero + 1
            slice2 = x[zero:one]
            return slice1 + slice2

        self.checkScript(t, (torch.zeros(3, 2, 3),))

    def test_tuple_slicing(self):
        # Slices a 4-tuple twice; for a truthy input the result is (2, 3).
        def tuple_slice(a):
            if bool(a):
                b = (1, 2, 3, 4)
            else:
                b = (4, 3, 2, 1)
            c = b[-4:4]
            e = c[1:-1]
            return e

        self.checkScript(tuple_slice, (torch.tensor([1]),), optimize=True)
        scripted_fn = torch.jit.script(tuple_slice)
        self.assertEqual(scripted_fn(torch.tensor(1)), (2, 3))
        tuple_graph = scripted_fn.graph
        slices = tuple_graph.findAllNodes("prim::TupleConstruct")
        num_outputs = {len(x.output().type().elements()) for x in slices}
        # every prim::TupleConstruct node found must produce a 2-element tuple
        self.assertEqual(num_outputs, {2})
        self.run_pass('lower_all_tuples', tuple_graph)
        # after the lowering pass the graph text must not mention 'Tuple'
        self.assertNotIn('Tuple', str(tuple_graph))

    def test_module_list_slicing(self):
        # Slicing a ModuleList inside a scripted module's forward().
        class Bar(torch.nn.Module):
            def __init__(self, identifier: str):
                super().__init__()
                # Label used below to check which modules each slice selected.
                self.identifier = identifier

            def forward(self):
                return 0

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                module_list = [Bar("A"), Bar("B"), Bar("C"), Bar("D"), Bar("E")]
                self.test = torch.nn.ModuleList(module_list)

            def forward(self):
                return self.test[::-2], self.test[1:4:]

        scripted_foo = torch.jit.script(Foo())
        result1, result2 = scripted_foo()

        # [::-2] over A..E walks backwards in steps of two: E, C, A.
        self.assertEqual(len(result1), 3)
        self.assertEqual(result1[0].identifier, "E")
        self.assertEqual(result1[1].identifier, "C")
        self.assertEqual(result1[2].identifier, "A")

        # [1:4:] over A..E selects positions 1 through 3: B, C, D.
        self.assertEqual(len(result2), 3)
        self.assertEqual(result2[0].identifier, "B")
        self.assertEqual(result2[1].identifier, "C")
        self.assertEqual(result2[2].identifier, "D")
 |