Support both train / eval modes for ModuleInfo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78735

Approved by: https://github.com/albanD
Author: Joel Benjamin Schlosser
Date: 2022-06-09 12:37:29 -04:00
Committed by: PyTorch MergeBot
Parent: 79f18c1aee
Commit: 70d6446a3d
3 changed files with 159 additions and 64 deletions
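
Note on the new test parametrization, for context before the diff: every test below gains a `training` argument, driven by a `train_eval_mode` option on the `@modules` decorator and a `TrainEvalMode` enum imported from torch.testing._internal.common_modules. The sketch below illustrates the semantics implied by the diff rather than the actual implementation; only the `train_only` member appears in the diff, so the other member names, the numeric values, the assumed default, and the helper function are illustrative assumptions.

from enum import Enum

class TrainEvalMode(Enum):
    # Only `train_only` is referenced in the diff; the remaining members,
    # their values, and the default are assumptions for illustration.
    train_only = 0       # run each test with training=True only
    eval_only = 1        # run each test with training=False only
    train_and_eval = 2   # run both variants (assumed default for @modules)

def training_flags(mode):
    # Hypothetical helper: the `training` values a decorated test would be
    # parametrized over for a given mode.
    if mode == TrainEvalMode.train_only:
        return [True]
    if mode == TrainEvalMode.eval_only:
        return [False]
    return [True, False]

print(training_flags(TrainEvalMode.train_only))  # [True]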


@@ -9,7 +9,7 @@ from operator import methodcaller
import torch
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, onlyCUDA, toleranceOverride, tol, skipMeta)
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_modules import module_db, modules, TrainEvalMode
from torch.testing._internal.common_utils import (
TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck, gradgradcheck, skipIfMps)
from unittest.mock import patch, call
@@ -42,10 +42,10 @@ class TestModule(TestCase):
@skipIfMps # the test doesn't work on MPS as double types are not supported
@modules(module_db)
def test_forward(self, device, dtype, module_info):
def test_forward(self, device, dtype, module_info, training):
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=False)
requires_grad=False, training=training)
dtype_to_method_caller = {
torch.float32: methodcaller("float"),
torch.float64: methodcaller("double"),
@@ -59,6 +59,7 @@ class TestModule(TestCase):
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
m = module_cls(*args, **kwargs)
m.to(device).to(dtype)
m.train(training)
# === Do forward pass. ===
args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
@@ -80,10 +81,10 @@ class TestModule(TestCase):
# Tests passing factory kwargs (e.g. device / dtype) during module instantiation.
# They should be applied to any created parameters and buffers.
@modules(module_db)
def test_factory_kwargs(self, device, dtype, module_info):
def test_factory_kwargs(self, device, dtype, module_info, training):
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=False)
requires_grad=False, training=training)
for module_input in module_inputs:
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
@@ -96,6 +97,7 @@ class TestModule(TestCase):
register_buffer = mock_wrapper(torch.nn.Module.register_buffer)
with patch.object(torch.nn.Module, 'register_buffer', register_buffer):
m = module_cls(*args, **kwargs)
m.train(training)
# Check if a parameter or buffer was created with a tensor not passed to the constructor.
constructor_tensors = get_tensors_from(args, kwargs)
@@ -122,6 +124,7 @@ class TestModule(TestCase):
uninit_buffer_new = mock_wrapper(torch.nn.UninitializedBuffer.__new__)
with patch.object(torch.nn.UninitializedBuffer, '__new__', uninit_buffer_new):
m = module_cls(*args, **kwargs)
m.train(training)
uninit_param_new.mock.assert_has_calls(
[call(device=device, dtype=dtype) for _ in uninit_param_new.mock.mock_calls])
uninit_buffer_new.mock.assert_has_calls(
@@ -130,16 +133,17 @@ class TestModule(TestCase):
# Check device placement and dtype for created parameters and buffers.
# Only verify floating point dtypes since that's what the kwarg applies to.
m = module_cls(*args, **kwargs)
m.train(training)
self._assert_module_parameters_and_buffer_are(m, device, dtype)
@onlyCUDA
@modules(module_db)
def test_multiple_device_transfer(self, device, dtype, module_info):
def test_multiple_device_transfer(self, device, dtype, module_info, training):
module_cls = module_info.module_cls
module_inputs_device = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=False)
requires_grad=False, training=training)
module_inputs_cpu = module_info.module_inputs_func(module_info, device="cpu", dtype=dtype,
requires_grad=False)
requires_grad=False, training=training)
for module_input_device, module_input_cpu in zip(module_inputs_device, module_inputs_cpu):
if module_input_device.forward_input is None:
continue
@@ -149,6 +153,7 @@ class TestModule(TestCase):
args, kwargs = module_input_device.constructor_input.args, module_input_device.constructor_input.kwargs
m = module_cls(*args, **kwargs)
m.to(device).to(dtype)
m.train(training)
# === Do forward pass on GPU ===
input_device_args = module_input_device.forward_input.args
@@ -189,14 +194,16 @@ class TestModule(TestCase):
@modules(module_db)
def test_repr(self, device, dtype, module_info):
def test_repr(self, device, dtype, module_info, training):
# Test module can be represented with repr and str without errors.
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=False)
requires_grad=False, training=training)
for module_input in module_inputs:
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
m = module_cls(*args, **kwargs)
m.to(device).to(dtype)
m.train(training)
# Check that these methods do not raise errors
m.__repr__()
@@ -204,11 +211,11 @@ class TestModule(TestCase):
@skipIfMps
@modules(module_db)
def test_pickle(self, device, dtype, module_info):
def test_pickle(self, device, dtype, module_info, training):
# Test that module can be pickled and unpickled.
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=False)
requires_grad=False, training=training)
for module_input in module_inputs:
if module_input.forward_input is None:
continue
@@ -220,6 +227,7 @@ class TestModule(TestCase):
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
m = module_cls(*args, **kwargs)
m.to(device).to(dtype)
m.train(training)
# === Do forward pass. ===
args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
@@ -233,15 +241,15 @@ class TestModule(TestCase):
output_from_copy = m_copy(*args, **kwargs)
self.assertEqual(output, output_from_copy)
@skipMeta
@modules([module_info for module_info in module_db
if 'inplace' in signature(module_info.module_cls).parameters])
@skipMeta
def test_check_inplace(self, device, dtype, module_info):
def test_check_inplace(self, device, dtype, module_info, training):
# Check if the inplace variant of the module gives the same result as the out of place
# variant.
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=True)
requires_grad=True, training=training)
for module_input in module_inputs:
if module_input.forward_input is None:
continue
@@ -250,8 +258,10 @@ class TestModule(TestCase):
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
m_op = module_cls(*args, **kwargs, inplace=False)
m_op.to(device).to(dtype)
m_op.train(training)
m_inplace = module_cls(*args, **kwargs, inplace=True)
m_inplace.to(device).to(dtype)
m_inplace.train(training)
# === Inplace modules only supports inplace operations on the first argument ===
input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
@@ -315,12 +325,12 @@ class TestModule(TestCase):
@skipIfMps
@modules(module_db)
def test_non_contiguous_tensors(self, device, dtype, module_info):
def test_non_contiguous_tensors(self, device, dtype, module_info, training):
# Check modules work with non-contiguous tensors
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=True)
requires_grad=True, training=training)
def _make_non_contiguous(obj):
def inner_make_non_contiguous(obj):
@@ -357,6 +367,7 @@ class TestModule(TestCase):
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
m = module_cls(*args, **kwargs)
m.to(device).to(dtype)
m.train(training)
self._retain_grad((input_args, input_kwargs))
@@ -409,11 +420,11 @@ class TestModule(TestCase):
self.assertEqual(param_grad, default_param_grad)
def _test_gradients_helper(self, device, dtype, module_info, check):
def _test_gradients_helper(self, device, dtype, module_info, training, check):
# Check gradients
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=True)
requires_grad=True, training=training)
# === Set nondet tol for gradcheck to user-defined value if on CUDA and cudNN is enabled
gradcheck_nondet_tol = 0.0
if (torch.device(device).type == 'cuda' and torch.backends.cudnn.enabled):
@@ -427,6 +438,7 @@ class TestModule(TestCase):
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
m = module_cls(*args, **kwargs)
m.to(device).to(dtype)
m.train(training)
params = tuple(m.parameters())
@@ -464,23 +476,33 @@ class TestModule(TestCase):
@modules(module_db, allowed_dtypes=[torch.double])
def test_grad(self, device, dtype, module_info):
self._test_gradients_helper(device, dtype, module_info, gradcheck)
def test_grad(self, device, dtype, module_info, training):
self._test_gradients_helper(device, dtype, module_info, training, gradcheck)
@modules([m for m in module_db if m.supports_gradgrad],
allowed_dtypes=[torch.double])
def test_gradgrad(self, device, dtype, module_info):
self._test_gradients_helper(device, dtype, module_info, gradgradcheck)
def test_gradgrad(self, device, dtype, module_info, training):
self._test_gradients_helper(device, dtype, module_info, training, gradgradcheck)
@onlyCUDA
@toleranceOverride({torch.float32: tol(5e-2, 0),
torch.float64: tol(4e-4, 0)})
@modules(module_db)
def test_cpu_gpu_parity(self, device, dtype, module_info):
def test_cpu_gpu_parity(self, device, dtype, module_info, training):
# TODO: RNN / GRU / LSTM don't support backwards on eval mode for cuDNN; skip this in a
# nicer way for eval mode only.
# See https://github.com/pytorch/pytorch/issues/79161
rnn_modules = set([torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM])
if (module_info.module_cls in rnn_modules
and not training
and 'cuda' in device
and torch.backends.cudnn.enabled):
return
# Test cpu and gpu results are the same
module_cls = module_info.module_cls
module_inputs_cpu = module_info.module_inputs_func(module_info, device="cpu", dtype=dtype,
requires_grad=True)
requires_grad=True, training=training)
def _to_device(obj):
if isinstance(obj, torch.Tensor):
@@ -495,7 +517,6 @@ class TestModule(TestCase):
return deepcopy(obj)
for module_input in module_inputs_cpu:
# === Move input from cpu to device ===
cpu_forward_args = module_input.forward_input.args
cpu_forward_kwargs = module_input.forward_input.kwargs
@@ -508,7 +529,9 @@ class TestModule(TestCase):
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
cpu_module = module_cls(*args, **kwargs).to(dtype).to("cpu")
cpu_module.train(training)
gpu_module = module_cls(*args, **kwargs).to(dtype).to(device)
gpu_module.train(training)
for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()):
gpu_p.data.copy_(cpu_p)
@@ -549,10 +572,10 @@ class TestModule(TestCase):
@skipIfMps
@modules(module_db)
def test_memory_format(self, device, dtype, module_info):
def test_memory_format(self, device, dtype, module_info, training):
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=False)
requires_grad=False, training=training)
module_memformat_affects_out = module_info.module_memformat_affects_out
def _get_mem_formats(channels_last=False, channels_last_3d=False):
@@ -613,6 +636,7 @@ class TestModule(TestCase):
m = module_cls(*args, **kwargs)
m.to(device).to(dtype)
m.train(training)
# === Get output in (contiguous, contiguous) configuration. ===
args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
@@ -640,6 +664,42 @@ class TestModule(TestCase):
# === Check mem format of output. ===
_check_out_mem_format(outputs, input_mem_format, module_mem_format)
# Test whether train and eval modes differ for each module. Use to verify
# that the ModuleInfo entry flag is correct.
@skipIfMps # the test doesn't work on MPS as double types are not supported
@modules(module_db, train_eval_mode=TrainEvalMode.train_only)
def test_if_train_and_eval_modes_differ(self, device, dtype, module_info, training):
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=False, training=training)
# Run forward inputs through to see if the training flag is accessed during forward.
for module_input in module_inputs:
if module_input.forward_input is None:
continue
# === Instantiate the module. ===
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
m = module_cls(*args, **kwargs)
m.to(device).to(dtype)
m.train(training)
# Remove training attribute and see if forward still works.
delattr(m, 'training')
# === Do forward pass. ===
try:
args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
m(*args, **kwargs)
except AttributeError as e:
if "'training'" in str(e):
self.assertTrue(module_info.train_and_eval_differ,
f"The ModuleInfo entry for {module_info.name} has "
"train_and_eval_differ=False, but the training mode was found to "
"affect the forward pass. Consider setting train_and_eval_differ=True "
"for this ModuleInfo entry.")
else:
raise e
instantiate_device_type_tests(TestModule, globals())
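
For a standalone illustration of the detection trick used by test_if_train_and_eval_modes_differ above, the snippet below deletes the `training` attribute so that any read of self.training during forward raises AttributeError, exactly as the test does; the helper name and the example modules are chosen here for illustration and are not part of the commit.

import torch

def forward_reads_training(m, *inputs):
    # Remove the instance attribute set by nn.Module.__init__; a forward
    # pass that consults self.training will then raise AttributeError.
    delattr(m, 'training')
    try:
        m(*inputs)
        return False
    except AttributeError as e:
        if "'training'" in str(e):
            return True
        raise

x = torch.randn(4, 8)
print(forward_reads_training(torch.nn.Dropout(p=0.5), x))  # True: Dropout branches on self.training
print(forward_reads_training(torch.nn.Linear(8, 3), x))    # False: Linear ignores the training flag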