mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	Initial ModuleInfo implementation (#61935)
Summary:
This PR contains the initial version of `ModuleInfo` for use in testing modules. The design philosophy taken here is to start small and simple and build out / refactor as needed when more test coverage or `ModuleInfo` entries are added. As such, it's not intended for general usage yet. The PR contains the following:
* (new file) `torch/testing/_internal/common_modules.py`
  * `ModuleInfo` definition - metadata for each module to use in testing
  * `module_db` - the actual `ModuleInfo` database; currently contains entries for two modules
  * `ModuleInput` - analogous to `SampleInput` from OpInfo; contains `FunctionInput`s for both constructor and forward pass inputs
      * Constructor and forward pass inputs are tied together within a `ModuleInput` because they are likely correlated
  * `FunctionInput` - just contains args and kwargs to pass to a function (is there a nicer way to do this?)
  * `modules` decorator - analogous to `ops`; specifies a set of modules to run a test over
  * Some constants used to keep track of all modules under torch.nn:
      * `MODULE_NAMESPACES` - list of all namespaces containing modules
      * `MODULE_CLASSES` - list of all module class objects
      * `MODULE_CLASS_NAMES` - dict from module class object to nice name (e.g. torch.nn.Linear -> "nn.Linear")
* (new file) `test/test_modules.py`
    * Uses the above to define tests over modules
    * Currently, there is one test for demonstration, `test_forward`, which instantiates a module, runs its forward pass, and compares it to a reference, if one is defined
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61935
Reviewed By: mruberry
Differential Revision: D29881832
Pulled By: jbschlosser
fbshipit-source-id: cc05c7d85f190a3aa42d55d4c8b01847d1efd57f
			
			
This commit is contained in:
		
				
					committed by
					
						
						Facebook GitHub Bot
					
				
			
			
				
	
			
			
			
						parent
						
							afe3644321
						
					
				
				
					commit
					a0309f89f4
				
			
							
								
								
									
										42
									
								
								test/test_modules.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								test/test_modules.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,42 @@
 | 
			
		||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
 | 
			
		||||
from torch.testing._internal.common_modules import module_db, modules
 | 
			
		||||
from torch.testing._internal.common_utils import TestCase, run_tests, freeze_rng_state
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestModule(TestCase):
 | 
			
		||||
    _do_cuda_memory_leak_check = True
 | 
			
		||||
    _do_cuda_non_default_stream = True
 | 
			
		||||
    precision = 1e-5
 | 
			
		||||
    rel_tol = 1e-5
 | 
			
		||||
 | 
			
		||||
    @modules(module_db)
 | 
			
		||||
    def test_forward(self, device, dtype, module_info):
 | 
			
		||||
        module_cls = module_info.module_cls
 | 
			
		||||
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
 | 
			
		||||
                                                       requires_grad=False)
 | 
			
		||||
        for module_input in module_inputs:
 | 
			
		||||
            if module_input.forward_input is None:
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            with freeze_rng_state():
 | 
			
		||||
                # === Instantiate the module. ===
 | 
			
		||||
                args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
 | 
			
		||||
                m = module_cls(*args, **kwargs)
 | 
			
		||||
                m.to(device).to(dtype)
 | 
			
		||||
 | 
			
		||||
                # === Do forward pass. ===
 | 
			
		||||
                args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
 | 
			
		||||
                outputs = m(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
                # === Compare outputs to a reference if one is specified. ===
 | 
			
		||||
                # TODO: Handle precision
 | 
			
		||||
                reference_fn = module_input.reference_fn
 | 
			
		||||
                if reference_fn is not None:
 | 
			
		||||
                    ref_outputs = reference_fn(m, *args, **kwargs)
 | 
			
		||||
                    self.assertEqual(outputs, ref_outputs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
instantiate_device_type_tests(TestModule, globals())
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    run_tests()
 | 
			
		||||
		Reference in New Issue
	
	Block a user