mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Support both train / eval modes for ModuleInfo
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78735 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
79f18c1aee
commit
70d6446a3d
@ -9,7 +9,7 @@ from operator import methodcaller
|
||||
import torch
|
||||
from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests, onlyCUDA, toleranceOverride, tol, skipMeta)
|
||||
from torch.testing._internal.common_modules import module_db, modules
|
||||
from torch.testing._internal.common_modules import module_db, modules, TrainEvalMode
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck, gradgradcheck, skipIfMps)
|
||||
from unittest.mock import patch, call
|
||||
@ -42,10 +42,10 @@ class TestModule(TestCase):
|
||||
|
||||
@skipIfMps # the test doesn't work on MPS as double types are not supported
|
||||
@modules(module_db)
|
||||
def test_forward(self, device, dtype, module_info):
|
||||
def test_forward(self, device, dtype, module_info, training):
|
||||
module_cls = module_info.module_cls
|
||||
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
|
||||
requires_grad=False)
|
||||
requires_grad=False, training=training)
|
||||
dtype_to_method_caller = {
|
||||
torch.float32: methodcaller("float"),
|
||||
torch.float64: methodcaller("double"),
|
||||
@ -59,6 +59,7 @@ class TestModule(TestCase):
|
||||
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
|
||||
m = module_cls(*args, **kwargs)
|
||||
m.to(device).to(dtype)
|
||||
m.train(training)
|
||||
|
||||
# === Do forward pass. ===
|
||||
args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
|
||||
@ -80,10 +81,10 @@ class TestModule(TestCase):
|
||||
# Tests passing factory kwargs (e.g. device / dtype) during module instantiation.
|
||||
# They should be applied to any created parameters and buffers.
|
||||
@modules(module_db)
|
||||
def test_factory_kwargs(self, device, dtype, module_info):
|
||||
def test_factory_kwargs(self, device, dtype, module_info, training):
|
||||
module_cls = module_info.module_cls
|
||||
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
|
||||
requires_grad=False)
|
||||
requires_grad=False, training=training)
|
||||
for module_input in module_inputs:
|
||||
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
|
||||
|
||||
@ -96,6 +97,7 @@ class TestModule(TestCase):
|
||||
register_buffer = mock_wrapper(torch.nn.Module.register_buffer)
|
||||
with patch.object(torch.nn.Module, 'register_buffer', register_buffer):
|
||||
m = module_cls(*args, **kwargs)
|
||||
m.train(training)
|
||||
|
||||
# Check if a parameter or buffer was created with a tensor not passed to the constructor.
|
||||
constructor_tensors = get_tensors_from(args, kwargs)
|
||||
@ -122,6 +124,7 @@ class TestModule(TestCase):
|
||||
uninit_buffer_new = mock_wrapper(torch.nn.UninitializedBuffer.__new__)
|
||||
with patch.object(torch.nn.UninitializedBuffer, '__new__', uninit_buffer_new):
|
||||
m = module_cls(*args, **kwargs)
|
||||
m.train(training)
|
||||
uninit_param_new.mock.assert_has_calls(
|
||||
[call(device=device, dtype=dtype) for _ in uninit_param_new.mock.mock_calls])
|
||||
uninit_buffer_new.mock.assert_has_calls(
|
||||
@ -130,16 +133,17 @@ class TestModule(TestCase):
|
||||
# Check device placement and dtype for created parameters and buffers.
|
||||
# Only verify floating point dtypes since that's what the kwarg applies to.
|
||||
m = module_cls(*args, **kwargs)
|
||||
m.train(training)
|
||||
self._assert_module_parameters_and_buffer_are(m, device, dtype)
|
||||
|
||||
@onlyCUDA
|
||||
@modules(module_db)
|
||||
def test_multiple_device_transfer(self, device, dtype, module_info):
|
||||
def test_multiple_device_transfer(self, device, dtype, module_info, training):
|
||||
module_cls = module_info.module_cls
|
||||
module_inputs_device = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
|
||||
requires_grad=False)
|
||||
requires_grad=False, training=training)
|
||||
module_inputs_cpu = module_info.module_inputs_func(module_info, device="cpu", dtype=dtype,
|
||||
requires_grad=False)
|
||||
requires_grad=False, training=training)
|
||||
for module_input_device, module_input_cpu in zip(module_inputs_device, module_inputs_cpu):
|
||||
if module_input_device.forward_input is None:
|
||||
continue
|
||||
@ -149,6 +153,7 @@ class TestModule(TestCase):
|
||||
args, kwargs = module_input_device.constructor_input.args, module_input_device.constructor_input.kwargs
|
||||
m = module_cls(*args, **kwargs)
|
||||
m.to(device).to(dtype)
|
||||
m.train(training)
|
||||
|
||||
# === Do forward pass on GPU ===
|
||||
input_device_args = module_input_device.forward_input.args
|
||||
@ -189,14 +194,16 @@ class TestModule(TestCase):
|
||||
|
||||
|
||||
@modules(module_db)
|
||||
def test_repr(self, device, dtype, module_info):
|
||||
def test_repr(self, device, dtype, module_info, training):
|
||||
# Test module can be represented with repr and str without errors.
|
||||
module_cls = module_info.module_cls
|
||||
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
|
||||
requires_grad=False)
|
||||
requires_grad=False, training=training)
|
||||
for module_input in module_inputs:
|
||||
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
|
||||
m = module_cls(*args, **kwargs)
|
||||
m.to(device).to(dtype)
|
||||
m.train(training)
|
||||
|
||||
# Check that these methods do not raise errors
|
||||
m.__repr__()
|
||||
@ -204,11 +211,11 @@ class TestModule(TestCase):
|
||||
|
||||
@skipIfMps
|
||||
@modules(module_db)
|
||||
def test_pickle(self, device, dtype, module_info):
|
||||
def test_pickle(self, device, dtype, module_info, training):
|
||||
# Test that module can be pickled and unpickled.
|
||||
module_cls = module_info.module_cls
|
||||
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
|
||||
requires_grad=False)
|
||||
requires_grad=False, training=training)
|
||||
for module_input in module_inputs:
|
||||
if module_input.forward_input is None:
|
||||
continue
|
||||
@ -220,6 +227,7 @@ class TestModule(TestCase):
|
||||
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
|
||||
m = module_cls(*args, **kwargs)
|
||||
m.to(device).to(dtype)
|
||||
m.train(training)
|
||||
|
||||
# === Do forward pass. ===
|
||||
args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
|
||||
@ -233,15 +241,15 @@ class TestModule(TestCase):
|
||||
output_from_copy = m_copy(*args, **kwargs)
|
||||
self.assertEqual(output, output_from_copy)
|
||||
|
||||
@skipMeta
|
||||
@modules([module_info for module_info in module_db
|
||||
if 'inplace' in signature(module_info.module_cls).parameters])
|
||||
@skipMeta
|
||||
def test_check_inplace(self, device, dtype, module_info):
|
||||
def test_check_inplace(self, device, dtype, module_info, training):
|
||||
# Check if the inplace variant of the module gives the same result as the out of place
|
||||
# variant.
|
||||
module_cls = module_info.module_cls
|
||||
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
|
||||
requires_grad=True)
|
||||
requires_grad=True, training=training)
|
||||
for module_input in module_inputs:
|
||||
if module_input.forward_input is None:
|
||||
continue
|
||||
@ -250,8 +258,10 @@ class TestModule(TestCase):
|
||||
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
|
||||
m_op = module_cls(*args, **kwargs, inplace=False)
|
||||
m_op.to(device).to(dtype)
|
||||
m_op.train(training)
|
||||
m_inplace = module_cls(*args, **kwargs, inplace=True)
|
||||
m_inplace.to(device).to(dtype)
|
||||
m_inplace.train(training)
|
||||
|
||||
# === Inplace modules only supports inplace operations on the first argument ===
|
||||
input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
|
||||
@ -315,12 +325,12 @@ class TestModule(TestCase):
|
||||
|
||||
@skipIfMps
|
||||
@modules(module_db)
|
||||
def test_non_contiguous_tensors(self, device, dtype, module_info):
|
||||
def test_non_contiguous_tensors(self, device, dtype, module_info, training):
|
||||
# Check modules work with non-contiguous tensors
|
||||
|
||||
module_cls = module_info.module_cls
|
||||
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
|
||||
requires_grad=True)
|
||||
requires_grad=True, training=training)
|
||||
|
||||
def _make_non_contiguous(obj):
|
||||
def inner_make_non_contiguous(obj):
|
||||
@ -357,6 +367,7 @@ class TestModule(TestCase):
|
||||
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
|
||||
m = module_cls(*args, **kwargs)
|
||||
m.to(device).to(dtype)
|
||||
m.train(training)
|
||||
|
||||
self._retain_grad((input_args, input_kwargs))
|
||||
|
||||
@ -409,11 +420,11 @@ class TestModule(TestCase):
|
||||
self.assertEqual(param_grad, default_param_grad)
|
||||
|
||||
|
||||
def _test_gradients_helper(self, device, dtype, module_info, check):
|
||||
def _test_gradients_helper(self, device, dtype, module_info, training, check):
|
||||
# Check gradients
|
||||
module_cls = module_info.module_cls
|
||||
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
|
||||
requires_grad=True)
|
||||
requires_grad=True, training=training)
|
||||
# === Set nondet tol for gradcheck to user-defined value if on CUDA and cudNN is enabled
|
||||
gradcheck_nondet_tol = 0.0
|
||||
if (torch.device(device).type == 'cuda' and torch.backends.cudnn.enabled):
|
||||
@ -427,6 +438,7 @@ class TestModule(TestCase):
|
||||
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
|
||||
m = module_cls(*args, **kwargs)
|
||||
m.to(device).to(dtype)
|
||||
m.train(training)
|
||||
|
||||
params = tuple(m.parameters())
|
||||
|
||||
@ -464,23 +476,33 @@ class TestModule(TestCase):
|
||||
|
||||
|
||||
@modules(module_db, allowed_dtypes=[torch.double])
|
||||
def test_grad(self, device, dtype, module_info):
|
||||
self._test_gradients_helper(device, dtype, module_info, gradcheck)
|
||||
def test_grad(self, device, dtype, module_info, training):
|
||||
self._test_gradients_helper(device, dtype, module_info, training, gradcheck)
|
||||
|
||||
@modules([m for m in module_db if m.supports_gradgrad],
|
||||
allowed_dtypes=[torch.double])
|
||||
def test_gradgrad(self, device, dtype, module_info):
|
||||
self._test_gradients_helper(device, dtype, module_info, gradgradcheck)
|
||||
def test_gradgrad(self, device, dtype, module_info, training):
|
||||
self._test_gradients_helper(device, dtype, module_info, training, gradgradcheck)
|
||||
|
||||
@onlyCUDA
|
||||
@toleranceOverride({torch.float32: tol(5e-2, 0),
|
||||
torch.float64: tol(4e-4, 0)})
|
||||
@modules(module_db)
|
||||
def test_cpu_gpu_parity(self, device, dtype, module_info):
|
||||
def test_cpu_gpu_parity(self, device, dtype, module_info, training):
|
||||
# TODO: RNN / GRU / LSTM don't support backwards on eval mode for cuDNN; skip this in a
|
||||
# nicer way for eval mode only.
|
||||
# See https://github.com/pytorch/pytorch/issues/79161
|
||||
rnn_modules = set([torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM])
|
||||
if (module_info.module_cls in rnn_modules
|
||||
and not training
|
||||
and 'cuda' in device
|
||||
and torch.backends.cudnn.enabled):
|
||||
return
|
||||
|
||||
# Test cpu and gpu results are the same
|
||||
module_cls = module_info.module_cls
|
||||
module_inputs_cpu = module_info.module_inputs_func(module_info, device="cpu", dtype=dtype,
|
||||
requires_grad=True)
|
||||
requires_grad=True, training=training)
|
||||
|
||||
def _to_device(obj):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
@ -495,7 +517,6 @@ class TestModule(TestCase):
|
||||
return deepcopy(obj)
|
||||
|
||||
for module_input in module_inputs_cpu:
|
||||
|
||||
# === Move input from cpu to device ===
|
||||
cpu_forward_args = module_input.forward_input.args
|
||||
cpu_forward_kwargs = module_input.forward_input.kwargs
|
||||
@ -508,7 +529,9 @@ class TestModule(TestCase):
|
||||
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
|
||||
|
||||
cpu_module = module_cls(*args, **kwargs).to(dtype).to("cpu")
|
||||
cpu_module.train(training)
|
||||
gpu_module = module_cls(*args, **kwargs).to(dtype).to(device)
|
||||
gpu_module.train(training)
|
||||
|
||||
for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()):
|
||||
gpu_p.data.copy_(cpu_p)
|
||||
@ -549,10 +572,10 @@ class TestModule(TestCase):
|
||||
|
||||
@skipIfMps
|
||||
@modules(module_db)
|
||||
def test_memory_format(self, device, dtype, module_info):
|
||||
def test_memory_format(self, device, dtype, module_info, training):
|
||||
module_cls = module_info.module_cls
|
||||
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
|
||||
requires_grad=False)
|
||||
requires_grad=False, training=training)
|
||||
module_memformat_affects_out = module_info.module_memformat_affects_out
|
||||
|
||||
def _get_mem_formats(channels_last=False, channels_last_3d=False):
|
||||
@ -613,6 +636,7 @@ class TestModule(TestCase):
|
||||
|
||||
m = module_cls(*args, **kwargs)
|
||||
m.to(device).to(dtype)
|
||||
m.train(training)
|
||||
|
||||
# === Get output in (contiguous, contiguous) configuration. ===
|
||||
args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
|
||||
@ -640,6 +664,42 @@ class TestModule(TestCase):
|
||||
# === Check mem format of output. ===
|
||||
_check_out_mem_format(outputs, input_mem_format, module_mem_format)
|
||||
|
||||
# Test whether train and eval modes differ for each module. Use to verify
|
||||
# that the ModuleInfo entry flag is correct.
|
||||
@skipIfMps # the test doesn't work on MPS as double types are not supported
|
||||
@modules(module_db, train_eval_mode=TrainEvalMode.train_only)
|
||||
def test_if_train_and_eval_modes_differ(self, device, dtype, module_info, training):
|
||||
module_cls = module_info.module_cls
|
||||
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
|
||||
requires_grad=False, training=training)
|
||||
|
||||
# Run forward inputs through to see if the training flag is accessed during forward.
|
||||
for module_input in module_inputs:
|
||||
if module_input.forward_input is None:
|
||||
continue
|
||||
|
||||
# === Instantiate the module. ===
|
||||
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
|
||||
m = module_cls(*args, **kwargs)
|
||||
m.to(device).to(dtype)
|
||||
m.train(training)
|
||||
|
||||
# Remove training attribute and see if forward still works.
|
||||
delattr(m, 'training')
|
||||
|
||||
# === Do forward pass. ===
|
||||
try:
|
||||
args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
|
||||
m(*args, **kwargs)
|
||||
except AttributeError as e:
|
||||
if "'training'" in str(e):
|
||||
self.assertTrue(module_info.train_and_eval_differ,
|
||||
f"The ModuleInfo entry for {module_info.name} has "
|
||||
"train_and_eval_differ=False, but the training mode was found to "
|
||||
"affect the forward pass. Consider setting train_and_eval_differ=True "
|
||||
"for this ModuleInfo entry.")
|
||||
else:
|
||||
raise e
|
||||
|
||||
instantiate_device_type_tests(TestModule, globals())
|
||||
|
||||
|
Reference in New Issue
Block a user