mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Summary: **Summary** This commit modifies the JIT frontend to handle `del` statements with variables as targets by dropping the mapping corresponding to that variable from the environment stack maintained by the IR emitter code. **Test Plan** This commit adds test cases for deleting a variable, deleting a variable and then using it, and deleting a variable in an if-statement and then using it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/37608 Differential Revision: D21507239 Pulled By: SplitInfinity fbshipit-source-id: ac7e353817dc76990ece294c95965cf585d6bdfb
		
			
				
	
	
		
			168 lines
		
	
	
		
			5.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			168 lines
		
	
	
		
			5.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import inspect
import os
import sys
from typing import Dict, List

import torch
 | |
| 
 | |
# Make the helper files in test/ importable. Starting from this file's
# resolved path, walk up two directory levels and append that directory to
# the module search path.
pytorch_test_dir = os.path.realpath(__file__)
pytorch_test_dir = os.path.dirname(pytorch_test_dir)  # directory of this file
pytorch_test_dir = os.path.dirname(pytorch_test_dir)  # ...and its parent
sys.path.append(pytorch_test_dir)
 | |
| from torch.testing._internal.jit_utils import JitTestCase
 | |
| 
 | |
if __name__ == "__main__":
    # Running this module standalone is refused; per the message below, its
    # tests are meant to be selected through test/test_jit.py instead.
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
 | |
| 
 | |
class TestBuiltins(JitTestCase):
    """
    Tests for TorchScript support of Python builtins: the ``hasattr``
    function and the ``del`` statement.
    """

    def test_has_attr(self):
        """``hasattr`` on ModuleList entries compiles, and a true result
        allows the attribute to be read inside the guarded branch."""
        class HasA(torch.nn.Module):
            def __init__(self):
                super(HasA, self).__init__()
                self.a = 0

        class HasB(torch.nn.Module):
            def __init__(self):
                super(HasB, self).__init__()
                self.b = 1

        class Mod(torch.nn.Module):
            def __init__(self):
                super(Mod, self).__init__()
                self.mods = torch.nn.ModuleList([HasA(), HasB()])

            def forward(self):
                # use a list to encode hasattr results
                results = torch.jit.annotate(List[int], [])
                for mod in self.mods:
                    results.append(int(hasattr(mod, "a")))
                    results.append(int(hasattr(mod, "b")))
                    # actually retrieve the attr to test static refinement
                    if hasattr(mod, "a"):
                        results.append(mod.a)
                    if hasattr(mod, "b"):
                        results.append(mod.b)
                return results

        self.checkModule(Mod(), ())

    def test_has_attr_invalid_args(self):
        """``hasattr`` calls whose arguments cannot be handled statically are
        rejected when the module is scripted."""
        class ModDynamicName(torch.nn.Module):
            def __init__(self):
                super(ModDynamicName, self).__init__()
                self.mod = torch.nn.Linear(1, 1)

            def forward(self, name):
                # not allowed, `name` must be static.
                return hasattr(self.mod, name)

        with self.assertRaisesRegex(RuntimeError, "hasattr"):
            torch.jit.script(ModDynamicName())

        class ModTensorTarget(torch.nn.Module):
            def __init__(self):
                super(ModTensorTarget, self).__init__()

            def forward(self, name):
                # not allowed, `torch.rand` is not a class type
                return hasattr(torch.rand(2, 3), name)

        with self.assertRaisesRegex(RuntimeError, "hasattr"):
            torch.jit.script(ModTensorTarget())

    def test_del(self):
        """``del`` on a variable removes it from the scripted scope."""
        def fn(x):
            # type: (List[int]) -> List[int]
            a = x * 2
            del a
            return x

        self.checkScript(fn, ([1, 2, 3],))

        # Reading a variable after deleting it must fail to compile.
        with self.assertRaisesRegex(RuntimeError, "undefined value"):
            @torch.jit.script
            def use_after_del(x):
                a = x ** 2
                del a
                return a  # noqa: F821 -- intentionally undefined

        # Deleting on only one branch still makes the later read an error.
        with self.assertRaisesRegex(RuntimeError, "undefined value"):
            @torch.jit.script
            def use_after_conditional_del(x):
                a = x ** 2
                if a:
                    del a
                return a

        # Deleting a name that was never bound must fail to compile.
        with self.assertRaisesRegex(RuntimeError, "undefined value"):
            @torch.jit.script
            def del_undefined(x):
                a = x ** 2
                del b  # noqa: F821 -- intentionally undefined
                return a

    def test_del_multiple_operands(self):
        """``del`` with several comma-separated targets is rejected by the
        frontend, for both list and dict subscripts."""
        with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError,
                                    "with more than one operand"):
            @torch.jit.script
            def del_list_multiple_operands(x):
                # type: (List[int]) -> List[int]
                del x[0], x[1]
                return x

        with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError,
                                    "with more than one operand"):
            @torch.jit.script
            def del_dict_multiple_operands(x):
                # type: (Dict[str, int]) -> Dict[str, int]
                del x['hi'], x['there']
                return x
 | |
| 
 | |
| 
 | |
class TestTensorBuiltins(JitTestCase):
    """Tests that public, non-method Tensor attributes work in TorchScript."""

    def test_tensor_properties(self):
        """Each public, non-routine Tensor attribute must compile as
        ``x.<name>`` in TorchScript. Unless listed as a known mismatch, it must
        also return the same value as the eager attribute."""
        def should_keep(tensor, name):
            # Filter by name first: it is cheap, and it means the getters of
            # private/dunder attributes are never evaluated at all.
            if name.startswith('_'):
                return False
            # Methods (bound or builtin) are not properties; skip them.
            return not inspect.isroutine(getattr(tensor, name))

        tensor = torch.arange(4, dtype=torch.float).view(2, 2)
        properties = [p for p in dir(tensor) if should_keep(tensor, p)]

        code_template = """
        def fn(x):
            return x.{}
        """

        EQUALITY_MISMATCH = {
            # TorchScript doesn't have real enums so they return an int instead
            # of the actual value
            'dtype',
            'layout',
        }
        MISSING_PROPERTIES = {
            'grad_fn',
            # This is an undocumented property so it's not included
            "output_nr",
            # This has a longer implementation, maybe not worth copying to
            # TorchScript if named tensors don't work there anyways
            'names',
        }

        for p in properties:
            if p in MISSING_PROPERTIES:
                continue
            code = code_template.format(p)
            cu = torch.jit.CompilationUnit()
            # Compile before the EQUALITY_MISMATCH skip so that those
            # properties are still checked for TorchScript support.
            cu.define(code)
            if p in EQUALITY_MISMATCH:
                continue
            self.assertEqual(getattr(tensor, p), cu.fn(tensor))
 |