Fix ModuleInfo skip logic (#80471)

Fixes #80247

This PR:
* Refactors the skip logic as done for OpInfo in #62713, fixing the logic error: the old path collapsed every active `DecorateInfo` in `skips` into an unconditional skip instead of applying its actual decorators (see the sketch after this list)
* For tests that were wrongly skipped before and now fail:
    * Fixes `TestModule.test_cpu_gpu_parity` to support Lazy modules - this was affecting `LazyConv*`
    * Adds `@expectedFailure` decorators and a follow-up comment to address `Conv*` failures on `TestModule.test_memory_format`
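
The crux of the refactor, in a simplified standalone sketch (it reuses the real `DecorateInfo` helper, which is an internal testing API; the variable names here are illustrative, not the actual implementation): each `DecorateInfo` now contributes its own decorators only when it is active for the given test class, test name, device, and dtype, instead of being collapsed into an unconditional skip.

```python
import unittest
import torch
from torch.testing._internal.common_methods_invocations import DecorateInfo

# One of the entries added by this PR: expected failure only for CUDA + float64.
entry = DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
                     device_type='cuda', dtypes=[torch.float64])

# New behavior: apply the entry's own decorators only when it is active.
active_decorators = []
if entry.is_active("TestModule", "test_memory_format", "cuda", torch.float64):
    active_decorators.extend(entry.decorators)  # -> [unittest.expectedFailure]

# The old path instead turned every active entry into skipIf(True, "Skipped!"),
# so conditional decorators (e.g. skipCUDAIfRocm) became unconditional skips.
```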
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80471
Approved by: https://github.com/mruberry
Author: Joel Benjamin Schlosser
Date: 2022-07-08 11:38:57 -04:00
Committed by: PyTorch MergeBot
Parent: 8fab682e47
Commit: d3dba3c42a
2 changed files with 49 additions and 19 deletions

Changed file 1 of 2:

@@ -533,6 +533,12 @@ class TestModule(TestCase):
             gpu_module = module_cls(*args, **kwargs).to(dtype).to(device)
             gpu_module.train(training)
 
+            # === Lazy modules need to see an input to initialize params ===
+            if issubclass(module_cls, torch.nn.modules.lazy.LazyModuleMixin):
+                with torch.no_grad():
+                    cpu_module(*cpu_forward_args, **cpu_forward_kwargs)
+                    gpu_module(*gpu_forward_args, **gpu_forward_kwargs)
+
             for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()):
                 gpu_p.data.copy_(cpu_p)
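
Background for the hunk above (not part of the diff): lazy modules keep `UninitializedParameter`s until they see a first input, so the parameter copy from the CPU module to the GPU module would fail without the dry-run forward. A standalone illustration with `nn.LazyLinear`:

```python
import torch
import torch.nn as nn

lazy = nn.LazyLinear(out_features=4)
# Before any forward pass the weight has no shape to copy into.
print(isinstance(lazy.weight, nn.parameter.UninitializedParameter))  # True

with torch.no_grad():
    lazy(torch.randn(2, 3))  # first forward materializes weight as a (4, 3) Parameter

print(lazy.weight.shape)  # torch.Size([4, 3]); parameters can now be copied across modules
```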

Changed file 2 of 2:

@@ -10,7 +10,7 @@ from torch.testing import make_tensor
 from torch.testing._internal.common_cuda import TEST_CUDNN
 from torch.testing._internal.common_dtype import floating_types
 from torch.testing._internal.common_device_type import (
-    _TestParametrizer, _update_param_kwargs, skipIf, toleranceOverride, tol,
+    _TestParametrizer, _update_param_kwargs, toleranceOverride, tol,
     skipCUDAIfCudnnVersionLessThan, skipCUDAIfRocm, precisionOverride, skipMeta)
 from torch.testing._internal.common_methods_invocations import DecorateInfo
 from torch.testing._internal.common_nn import nllloss_reference, get_reduction
@@ -104,25 +104,13 @@ class modules(_TestParametrizer):
             _update_param_kwargs(param_kwargs, 'training', training)
 
             try:
-                active_decorators = [set_single_threaded_if_parallel_tbb]
-                if module_info.should_skip(generic_cls.__name__, test.__name__, device_cls.device_type, dtype):
-                    active_decorators.append(skipIf(True, "Skipped!"))
-
-                if module_info.decorators is not None:
-                    for decorator in module_info.decorators:
-                        # Can't use isinstance as it would cause a circular import
-                        if decorator.__class__.__name__ == 'DecorateInfo':
-                            if decorator.is_active(generic_cls.__name__, test.__name__,
-                                                   device_cls.device_type, dtype):
-                                active_decorators += decorator.decorators
-                        else:
-                            active_decorators.append(decorator)
-
                 @wraps(test)
                 def test_wrapper(*args, **kwargs):
                     return test(*args, **kwargs)
 
-                for decorator in active_decorators:
+                for decorator in module_info.get_decorators(generic_cls.__name__, test.__name__,
+                                                            device_cls.device_type, dtype):
                     test_wrapper = decorator(test_wrapper)
 
                 yield (test_wrapper, test_name, param_kwargs)
@@ -187,16 +175,22 @@
                  ):
         self.module_cls = module_cls
         self.module_inputs_func = module_inputs_func
-        self.skips = skips
-        self.decorators = decorators
+        self.decorators = (*(decorators if decorators else []), *(skips if skips else []))
         self.dtypes = dtypes
         self.supports_gradgrad = supports_gradgrad
         self.gradcheck_nondet_tol = gradcheck_nondet_tol
         self.module_memformat_affects_out = module_memformat_affects_out
         self.train_and_eval_differ = train_and_eval_differ
 
-    def should_skip(self, cls_name, test_name, device_type, dtype):
-        return any(si.is_active(cls_name, test_name, device_type, dtype) for si in self.skips)
+    def get_decorators(self, test_class, test_name, device, dtype):
+        result = [set_single_threaded_if_parallel_tbb]
+        for decorator in self.decorators:
+            if isinstance(decorator, DecorateInfo):
+                if decorator.is_active(test_class, test_name, device, dtype):
+                    result.extend(decorator.decorators)
+            else:
+                result.append(decorator)
+        return result
 
     @property
     def name(self):
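
With `skips` and `decorators` folded into a single tuple, the decorators that would wrap a given test instance can be inspected directly. A hedged usage sketch (these are private testing helpers; the import path assumes the ModuleInfo database lives in `torch.testing._internal.common_modules`, as in current PyTorch):

```python
import torch
from torch.testing._internal.common_modules import module_db

# List the decorators that would apply to TestModule.test_memory_format on CUDA/float64
# for the first few ModuleInfo entries.
for module_info in module_db[:5]:
    active = module_info.get_decorators('TestModule', 'test_memory_format', 'cuda', torch.float64)
    names = [getattr(d, '__name__', type(d).__name__) for d in active]
    print(module_info.module_cls.__name__, names)
```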
@@ -1094,6 +1088,10 @@ module_db: List[ModuleInfo] = [
                    # Failure on ROCM for float32 issue #70125
                    DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
                    DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
+                                device_type='cuda', dtypes=[torch.float64]),
                ),
                decorators=(
                    DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
@@ -1108,6 +1106,9 @@ module_db: List[ModuleInfo] = [
                    # Failure on ROCM for float32 issue #70125
                    DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
                    DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
                ),
                decorators=(
                    DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
@@ -1136,6 +1137,11 @@ module_db: List[ModuleInfo] = [
                    # Failure on ROCM for float32 issue #70125
                    DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
                    DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cpu'),
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
+                                dtypes=[torch.float64]),
                ),
                decorators=(
                    DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
@@ -1150,6 +1156,9 @@ module_db: List[ModuleInfo] = [
                    # Failure on ROCM for float32 issue #70125
                    DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
                    DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
                ),
                decorators=(
                    DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
@@ -1196,6 +1205,10 @@ module_db: List[ModuleInfo] = [
                    # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
                    # See https://github.com/pytorch/pytorch/issues/70505 for more info.
                    DecorateInfo(skipMeta),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
+                                device_type='cuda', dtypes=[torch.float64]),
                ),
                decorators=(
                    DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
@@ -1213,6 +1226,9 @@ module_db: List[ModuleInfo] = [
                    # See https://github.com/pytorch/pytorch/issues/70505 for more info.
                    DecorateInfo(skipMeta),
                    DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
                ),
                decorators=(
                    DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
@@ -1247,6 +1263,11 @@ module_db: List[ModuleInfo] = [
                    # See https://github.com/pytorch/pytorch/issues/70505 for more info.
                    DecorateInfo(skipMeta),
                    DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cpu'),
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
+                                dtypes=[torch.float64]),
                ),
                decorators=(
                    DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
@@ -1264,6 +1285,9 @@ module_db: List[ModuleInfo] = [
                    # See https://github.com/pytorch/pytorch/issues/70505 for more info.
                    DecorateInfo(skipMeta),
                    DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
                ),
                decorators=(
                    DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
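
For context on what the added entries do at runtime: `unittest.expectedFailure` runs the test and records an expected failure rather than skipping it, so a future fix surfaces as an unexpected pass instead of staying hidden. A minimal standalone illustration (toy test, not from the PR):

```python
import unittest

class ToyMemoryFormatTest(unittest.TestCase):
    @unittest.expectedFailure
    def test_known_broken(self):
        # Stands in for the Conv* memory-format discrepancies tracked in issue #80247.
        self.assertEqual(1, 2)

if __name__ == "__main__":
    unittest.main()  # reports "OK (expected failures=1)" instead of a skip or an error
```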