TST Adds inplace checks to module_info (#63739)

Summary:
Follow-up to https://github.com/pytorch/pytorch/pull/61935

This PR adds inplace checks to `test_modules`. This version inspects each module's constructor signature for an `inplace` parameter and performs the check automatically for every module that has one.
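
As a rough illustration of the selection mechanism (a minimal sketch, not part of the diff; `nn.ReLU` and `nn.Linear` are just example modules):

    from inspect import signature
    import torch.nn as nn

    # A module qualifies for the automatic inplace check when its
    # constructor accepts an `inplace` argument.
    assert 'inplace' in signature(nn.ReLU).parameters        # has inplace=False, selected
    assert 'inplace' not in signature(nn.Linear).parameters  # no inplace arg, skipped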

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63739

Reviewed By: saketh-are

Differential Revision: D30737774

Pulled By: jbschlosser

fbshipit-source-id: 8813534511e9296c8424d1ca878412726ddd4043
Author: Thomas J. Fan
Date: 2021-09-08 11:00:11 -07:00
Committed by: Facebook GitHub Bot
Parent: a5ad08ec70
Commit: 43c0f033fc
2 changed files with 66 additions and 1 deletion

@@ -1,3 +1,5 @@
from inspect import signature
from copy import deepcopy
import tempfile
import torch
@@ -154,6 +156,55 @@ class TestModule(TestCase):
                    output_from_copy = m_copy(*args, **kwargs)
                    self.assertEqual(output, output_from_copy)

    @modules([module_info for module_info in module_db
              if 'inplace' in signature(module_info.module_cls).parameters])
    def test_check_inplace(self, device, dtype, module_info):
        # Check that the inplace variant of the module gives the same result as
        # the out-of-place variant.
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=True)
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m_op = module_cls(*args, **kwargs, inplace=False)
            m_op.to(device).to(dtype)
            m_inplace = module_cls(*args, **kwargs, inplace=True)
            m_inplace.to(device).to(dtype)

            # === Inplace modules only support inplace operations on the first argument. ===
            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs

            # === Do not allow the first input to be in input_kwargs. ===
            forward_sig = signature(m_op.forward).parameters
            self.assertGreaterEqual(len(forward_sig), 1)
            first_param_name = next(iter(forward_sig))
            self.assertNotIn(first_param_name, input_kwargs)

            # === Out-of-place operation does not write to the original tensor. ===
            self.assertGreaterEqual(len(input_args), 1)
            input_version = input_args[0]._version
            with freeze_rng_state():
                output_op = m_op(*input_args, **input_kwargs)
            self.assertEqual(input_args[0]._version, input_version)

            # === Check that the inplace operation gives the same result. ===
            input_arg_copy = deepcopy(input_args)
            input_arg_clone = tuple(i.clone() for i in input_arg_copy)
            with freeze_rng_state():
                output_ip = m_inplace(*input_arg_clone, **input_kwargs)
            self.assertNotEqual(input_arg_clone[0]._version, input_version)
            self.assertEqual(output_op, output_ip)

            # === Check that the gradients are the same. ===
            grad = output_op.data.clone().normal_()
            output_op.backward(grad)
            output_ip.backward(grad)
            self.assertEqual(input_args[0].grad, input_arg_copy[0].grad)

instantiate_device_type_tests(TestModule, globals())
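
For context on the `_version` assertions above: each PyTorch tensor carries an internal version counter that is bumped by every in-place mutation, which is how the test verifies that the out-of-place variant left its input untouched while the inplace variant wrote to it. A minimal illustrative sketch (not part of the diff):

    import torch

    t = torch.ones(3)
    v = t._version
    _ = t.relu()   # out-of-place: input tensor is not modified
    assert t._version == v
    t.relu_()      # in-place: the version counter increments
    assert t._version > v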