mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-31 12:15:03 +08:00 
			
		
		
		
	[test][scan] refactor inductor test and prepare for adding bw tests
[ghstack-poisoned]
This commit is contained in:
		| @ -1,5 +1,5 @@ | ||||
| # Owner(s): ["module: inductor"] | ||||
| import contextlib | ||||
|  | ||||
| import itertools | ||||
| import unittest | ||||
|  | ||||
| @ -1755,27 +1755,56 @@ class ScanTests(TestCase): | ||||
|         inputs, | ||||
|         device, | ||||
|         dynamic, | ||||
|         requires_grad=False, | ||||
|         autograd=False, | ||||
|     ): | ||||
|         cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor") | ||||
|         compiled_model = torch.compile(backend=cnt, fullgraph=True, dynamic=dynamic)( | ||||
|             model | ||||
|         ) | ||||
|         import copy | ||||
|  | ||||
|         inputs = [inp.requires_grad_(autograd) for inp in inputs] | ||||
|         inputs = [inp.to(device=device) for inp in inputs] | ||||
|         model = model.to(device=device) | ||||
|         cloned_inputs = [inp.clone() for inp in inputs] | ||||
|         grad_ctx = contextlib.nullcontext() if requires_grad else torch.no_grad() | ||||
|         with grad_ctx: | ||||
|             result = model(scan, *cloned_inputs) | ||||
|             result_exp = model(_fake_scan, *cloned_inputs) | ||||
|         for p in model.parameters(): | ||||
|             p.requires_grad_(autograd) | ||||
|  | ||||
|             result_compiled = compiled_model(scan, *cloned_inputs) | ||||
|             result_compiled_exp = compiled_model(_fake_scan, *cloned_inputs) | ||||
|         model1 = copy.deepcopy(model) | ||||
|         model2 = copy.deepcopy(model) | ||||
|         model3 = copy.deepcopy(model) | ||||
|         model4 = copy.deepcopy(model) | ||||
|         torch.compile(fullgraph=True, dynamic=dynamic)(model) | ||||
|  | ||||
|         self.assertEqual(result, result_exp) | ||||
|         def _run_model(model, inputs): | ||||
|             cloned_inputs = [ | ||||
|                 inp.clone() if isinstance(inp, torch.Tensor) else inp for inp in inputs | ||||
|             ] | ||||
|             fw_result = model(*cloned_inputs) | ||||
|             loss = loss_fn(fw_result) | ||||
|             if autograd: | ||||
|                 loss.backward() | ||||
|                 return ( | ||||
|                     fw_result, | ||||
|                     loss, | ||||
|                     [ | ||||
|                         inp.grad | ||||
|                         for inp in cloned_inputs | ||||
|                         if isinstance(inp, torch.Tensor) | ||||
|                     ], | ||||
|                     {n: p.grad for n, p in model.named_parameters()}, | ||||
|                 ) | ||||
|             else: | ||||
|                 return fw_result, loss | ||||
|  | ||||
|         result_exp = _run_model(model1, [_fake_scan] + inputs) | ||||
|         result_eager = _run_model(model2, [scan] + inputs) | ||||
|         result_compiled = _run_model( | ||||
|             torch.compile(fullgraph=True, dynamic=dynamic)(model3), [scan] + inputs | ||||
|         ) | ||||
|         result_compiled_exp = _run_model( | ||||
|             torch.compile(fullgraph=True, dynamic=dynamic)(model4), | ||||
|             [_fake_scan] + inputs, | ||||
|         ) | ||||
|  | ||||
|         self.assertEqual(result_exp, result_eager) | ||||
|         self.assertEqual(result_exp, result_compiled) | ||||
|         self.assertEqual(result_compiled, result_compiled_exp) | ||||
|         self.assertEqual(result_exp, result_compiled_exp) | ||||
|  | ||||
|     def _compare_result( | ||||
|         self, | ||||
|  | ||||
		Reference in New Issue
	
	Block a user