Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00.
			
		
		
		
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134665
Approved by: https://github.com/albanD
		
			
				
	
	
		
39 lines · 926 B · Python
		
	
	
	
	
	
| # Owner(s): ["module: dynamo"]
 | |
| 
 | |
| import torch
 | |
| import torch._dynamo.test_case
 | |
| 
 | |
| 
 | |
def fn_creator():
    """Build and return ``fn``, a closure used to inspect Dynamo resume functions.

    ``fn`` is constructed so that its code object has exactly one free
    variable (``var1``, read from this enclosing scope) and exactly one cell
    variable (``var2``, captured by the nested ``inner_fn``). It also contains
    a graph break, so compiling it is expected to produce a resume function
    for the code that follows the break.
    """
    # Read inside ``fn`` but defined here, so it becomes one of fn's co_freevars.
    var1 = 1

    def fn(x):
        x = x + 1
        # Captured by ``inner_fn`` below, which makes it one of fn's co_cellvars.
        var2 = 1
        # Force Dynamo to split ``fn`` here; the remainder of the body is what
        # the generated "__resume_at..." function covers.
        torch._dynamo.graph_break()
        # Uses the free variable after the break, so the resume code needs it.
        x = x + var1

        # Never called; it exists only so ``var2`` is a cell variable of ``fn``.
        def inner_fn():  # noqa: F841
            return var2

        return x

    return fn
 | |
| 
 | |
| 
 | |
class ResumeFunctionTests(torch._dynamo.test_case.TestCase):
    """Checks on the resume functions Dynamo generates after a graph break."""

    def test_freevars(self):
        """Compiling ``fn`` yields one resume function whose free vars are ``var1``, ``var2``."""
        original = fn_creator()
        compiled = torch.compile(original, backend="eager")
        compiled(torch.randn(10))

        # Resume functions are found in this module's globals by their
        # "__resume_at" name prefix; the stored values are code objects.
        # Take a snapshot first so the lookup iterates a stable mapping.
        module_namespace = dict(globals())
        resume_codes = [
            module_namespace[name]
            for name in module_namespace
            if name.startswith("__resume_at")
        ]
        self.assertEqual(len(resume_codes), 1)

        # The co_freevars of a resume function are the sorted concatenation of
        # the original function's co_freevars and co_cellvars: fn's only free
        # variable is var1 and its only cell variable is var2.
        self.assertEqual(resume_codes[0].co_freevars, ("var1", "var2"))
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Hand off to Dynamo's test runner through the module imported at the top.
    torch._dynamo.test_case.run_tests()
 |