mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	TST Adds inplace checks to module_info (#63739)
Summary: Follow up to https://github.com/pytorch/pytorch/pull/61935 This PR adds inplace checks to `test_modules`. This version checks the constructor for `inplace` and performs the check automatically. Pull Request resolved: https://github.com/pytorch/pytorch/pull/63739 Reviewed By: saketh-are Differential Revision: D30737774 Pulled By: jbschlosser fbshipit-source-id: 8813534511e9296c8424d1ca878412726ddd4043
This commit is contained in:
		
				
					committed by
					
						
						Facebook GitHub Bot
					
				
			
			
				
	
			
			
			
						parent
						
							a5ad08ec70
						
					
				
				
					commit
					43c0f033fc
				
			@ -1,3 +1,5 @@
 | 
			
		||||
from inspect import signature
 | 
			
		||||
from copy import deepcopy
 | 
			
		||||
import tempfile
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
@ -154,6 +156,55 @@ class TestModule(TestCase):
 | 
			
		||||
                    output_from_copy = m_copy(*args, **kwargs)
 | 
			
		||||
                    self.assertEqual(output, output_from_copy)
 | 
			
		||||
 | 
			
		||||
    @modules([module_info for module_info in module_db
 | 
			
		||||
              if 'inplace' in signature(module_info.module_cls).parameters])
 | 
			
		||||
    def test_check_inplace(self, device, dtype, module_info):
 | 
			
		||||
        # Check if the inplace variant of the module gives the same result as the out of place
 | 
			
		||||
        # variant.
 | 
			
		||||
        module_cls = module_info.module_cls
 | 
			
		||||
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
 | 
			
		||||
                                                       requires_grad=True)
 | 
			
		||||
        for module_input in module_inputs:
 | 
			
		||||
            if module_input.forward_input is None:
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            # === Instantiate the module. ===
 | 
			
		||||
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
 | 
			
		||||
            m_op = module_cls(*args, **kwargs, inplace=False)
 | 
			
		||||
            m_op.to(device).to(dtype)
 | 
			
		||||
            m_inplace = module_cls(*args, **kwargs, inplace=True)
 | 
			
		||||
            m_inplace.to(device).to(dtype)
 | 
			
		||||
 | 
			
		||||
            # === Inplace modules only supports inplace operations on the first argument ===
 | 
			
		||||
            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
 | 
			
		||||
 | 
			
		||||
            # ===  Do not allow the first input to be in input_kwargs ===
 | 
			
		||||
            forward_sig = signature(m_op).parameters
 | 
			
		||||
            self.assertGreaterEqual(len(forward_sig), 1)
 | 
			
		||||
            first_param_name = next(iter(forward_sig.items()))
 | 
			
		||||
            self.assertNotIn(first_param_name, input_kwargs)
 | 
			
		||||
 | 
			
		||||
            # === Out of place operation does not write to original tensor ===
 | 
			
		||||
            self.assertGreaterEqual(len(input_args), 1)
 | 
			
		||||
            input_version = input_args[0]._version
 | 
			
		||||
            with freeze_rng_state():
 | 
			
		||||
                output_op = m_op(*input_args, **input_kwargs)
 | 
			
		||||
            self.assertEqual(input_args[0]._version, input_version)
 | 
			
		||||
 | 
			
		||||
            # === Check that the inplace operation gives the same result ===
 | 
			
		||||
            input_arg_copy = deepcopy(input_args)
 | 
			
		||||
            input_arg_clone = tuple(i.clone() for i in input_arg_copy)
 | 
			
		||||
            with freeze_rng_state():
 | 
			
		||||
                output_ip = m_inplace(*input_arg_clone, **input_kwargs)
 | 
			
		||||
            self.assertNotEqual(input_arg_clone[0]._version, input_version)
 | 
			
		||||
            self.assertEqual(output_op, output_ip)
 | 
			
		||||
 | 
			
		||||
            # === Check that the gradients are the same ===
 | 
			
		||||
            grad = output_op.data.clone().normal_()
 | 
			
		||||
            output_op.backward(grad)
 | 
			
		||||
            output_ip.backward(grad)
 | 
			
		||||
            self.assertEqual(input_args[0].grad, input_arg_copy[0].grad)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
instantiate_device_type_tests(TestModule, globals())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user