mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Summary: This PR contains the initial version of `ModuleInfo` for use in testing modules. The design philosophy taken here is to start small and simple and build out / refactor as needed when more test coverage or `ModuleInfo` entries are added. As such, it's not intended for general usage yet. The PR contains the following: * (new file) `torch/testing/_internal/common_modules.py` * `ModuleInfo` definition - metadata for each module to use in testing * `module_db` - the actual `ModuleInfo` database; currently contains entries for two modules * `ModuleInput` - analogous to `SampleInput` from OpInfo; contains `FunctionInput`s for both constructor and forward pass inputs * Constructor and forward pass inputs are tied together within a `ModuleInput` because they are likely correlated * `FunctionInput` - just contains args and kwargs to pass to a function (is there a nicer way to do this?) * `modules` decorator - analogous to `ops`; specifies a set of modules to run a test over * Some constants used to keep track of all modules under torch.nn: * `MODULE_NAMESPACES` - list of all namespaces containing modules * `MODULE_CLASSES` - list of all module class objects * `MODULE_CLASS_NAMES` - dict from module class object to nice name (e.g. torch.nn.Linear -> "nn.Linear") * (new file) `test/test_modules.py` * Uses the above to define tests over modules * Currently, there is one test for demonstration, `test_forward`, which instantiates a module, runs its forward pass, and compares it to a reference, if one is defined Pull Request resolved: https://github.com/pytorch/pytorch/pull/61935 Reviewed By: mruberry Differential Revision: D29881832 Pulled By: jbschlosser fbshipit-source-id: cc05c7d85f190a3aa42d55d4c8b01847d1efd57f
43 lines
1.7 KiB
Python
43 lines
1.7 KiB
Python
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
|
from torch.testing._internal.common_modules import module_db, modules
|
|
from torch.testing._internal.common_utils import TestCase, run_tests, freeze_rng_state
|
|
|
|
|
|
class TestModule(TestCase):
    """Generic tests parametrized over the ``ModuleInfo`` entries in ``module_db``."""

    # Opt-in flags read by TestCase. By their names they enable a CUDA memory
    # leak check and running on a non-default CUDA stream (confirm in
    # common_utils).
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    # Absolute and relative tolerances; presumably picked up as comparison
    # defaults by TestCase.assertEqual (per-module precision is still a TODO
    # inside test_forward).
    precision = 1e-5
    rel_tol = 1e-5

    @modules(module_db)
    def test_forward(self, device, dtype, module_info):
        """Run the forward pass for every sample input of ``module_info``.

        For each sample returned by ``module_info.module_inputs_func`` that
        carries a forward input: build the module from the sample's
        constructor input, move it to ``device`` and then ``dtype``, call it
        with the forward input and, when the sample provides a
        ``reference_fn``, assert the outputs equal
        ``reference_fn(module, *forward_args, **forward_kwargs)``.
        Samples without a forward input are skipped.
        """
        module_type = module_info.module_cls
        samples = module_info.module_inputs_func(
            module_info, device=device, dtype=dtype, requires_grad=False
        )

        for sample in samples:
            if sample.forward_input is None:
                # No forward input for this sample; nothing to run.
                continue

            # freeze_rng_state is expected to restore the RNG state when the
            # block exits, so randomness consumed here does not leak out.
            with freeze_rng_state():
                ctor = sample.constructor_input
                module = module_type(*ctor.args, **ctor.kwargs)
                module.to(device).to(dtype)

                fwd = sample.forward_input
                fwd_args, fwd_kwargs = fwd.args, fwd.kwargs
                outputs = module(*fwd_args, **fwd_kwargs)

                # TODO: Handle precision
                ref_fn = sample.reference_fn
                if ref_fn is None:
                    continue
                self.assertEqual(outputs, ref_fn(module, *fwd_args, **fwd_kwargs))
# Presumably expands the generic TestModule into device-type-specific test
# classes registered in this module's namespace (confirm in
# common_device_type).
instantiate_device_type_tests(TestModule, globals())

# Allow running this file directly as a script.
if __name__ == '__main__':
    run_tests()