mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Fix ModuleInfo skip logic (#80471)
Fixes #80247 This PR: * Refactors the skip logic as done for OpInfo in #62713, fixing the logic error * For tests that were wrongly skipped before and now fail: * Fix `TestModule.test_cpu_gpu_parity` to support Lazy modules - this was affecting `LazyConv*` * Adds `@expectedFailure` decorators and a follow-up message to address `Conv*` failures on `TestModule.test_memory_format` Pull Request resolved: https://github.com/pytorch/pytorch/pull/80471 Approved by: https://github.com/mruberry
This commit is contained in:
committed by
PyTorch MergeBot
parent
8fab682e47
commit
d3dba3c42a
@ -533,6 +533,12 @@ class TestModule(TestCase):
|
||||
gpu_module = module_cls(*args, **kwargs).to(dtype).to(device)
|
||||
gpu_module.train(training)
|
||||
|
||||
# === Lazy modules need to see an input to initialize params ===
|
||||
if issubclass(module_cls, torch.nn.modules.lazy.LazyModuleMixin):
|
||||
with torch.no_grad():
|
||||
cpu_module(*cpu_forward_args, **cpu_forward_kwargs)
|
||||
gpu_module(*gpu_forward_args, **gpu_forward_kwargs)
|
||||
|
||||
for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()):
|
||||
gpu_p.data.copy_(cpu_p)
|
||||
|
||||
|
@ -10,7 +10,7 @@ from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_cuda import TEST_CUDNN
|
||||
from torch.testing._internal.common_dtype import floating_types
|
||||
from torch.testing._internal.common_device_type import (
|
||||
_TestParametrizer, _update_param_kwargs, skipIf, toleranceOverride, tol,
|
||||
_TestParametrizer, _update_param_kwargs, toleranceOverride, tol,
|
||||
skipCUDAIfCudnnVersionLessThan, skipCUDAIfRocm, precisionOverride, skipMeta)
|
||||
from torch.testing._internal.common_methods_invocations import DecorateInfo
|
||||
from torch.testing._internal.common_nn import nllloss_reference, get_reduction
|
||||
@ -104,25 +104,13 @@ class modules(_TestParametrizer):
|
||||
_update_param_kwargs(param_kwargs, 'training', training)
|
||||
|
||||
try:
|
||||
active_decorators = [set_single_threaded_if_parallel_tbb]
|
||||
if module_info.should_skip(generic_cls.__name__, test.__name__, device_cls.device_type, dtype):
|
||||
active_decorators.append(skipIf(True, "Skipped!"))
|
||||
|
||||
if module_info.decorators is not None:
|
||||
for decorator in module_info.decorators:
|
||||
# Can't use isinstance as it would cause a circular import
|
||||
if decorator.__class__.__name__ == 'DecorateInfo':
|
||||
if decorator.is_active(generic_cls.__name__, test.__name__,
|
||||
device_cls.device_type, dtype):
|
||||
active_decorators += decorator.decorators
|
||||
else:
|
||||
active_decorators.append(decorator)
|
||||
|
||||
@wraps(test)
|
||||
def test_wrapper(*args, **kwargs):
|
||||
return test(*args, **kwargs)
|
||||
|
||||
for decorator in active_decorators:
|
||||
for decorator in module_info.get_decorators(generic_cls.__name__, test.__name__,
|
||||
device_cls.device_type, dtype):
|
||||
test_wrapper = decorator(test_wrapper)
|
||||
|
||||
yield (test_wrapper, test_name, param_kwargs)
|
||||
@ -187,16 +175,22 @@ class ModuleInfo(object):
|
||||
):
|
||||
self.module_cls = module_cls
|
||||
self.module_inputs_func = module_inputs_func
|
||||
self.skips = skips
|
||||
self.decorators = decorators
|
||||
self.decorators = (*(decorators if decorators else []), *(skips if skips else []))
|
||||
self.dtypes = dtypes
|
||||
self.supports_gradgrad = supports_gradgrad
|
||||
self.gradcheck_nondet_tol = gradcheck_nondet_tol
|
||||
self.module_memformat_affects_out = module_memformat_affects_out
|
||||
self.train_and_eval_differ = train_and_eval_differ
|
||||
|
||||
def should_skip(self, cls_name, test_name, device_type, dtype):
|
||||
return any(si.is_active(cls_name, test_name, device_type, dtype) for si in self.skips)
|
||||
def get_decorators(self, test_class, test_name, device, dtype):
|
||||
result = [set_single_threaded_if_parallel_tbb]
|
||||
for decorator in self.decorators:
|
||||
if isinstance(decorator, DecorateInfo):
|
||||
if decorator.is_active(test_class, test_name, device, dtype):
|
||||
result.extend(decorator.decorators)
|
||||
else:
|
||||
result.append(decorator)
|
||||
return result
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
@ -1094,6 +1088,10 @@ module_db: List[ModuleInfo] = [
|
||||
# Failure on ROCM for float32 issue #70125
|
||||
DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
|
||||
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
|
||||
# This was wrongly being skipped before and needs investigation.
|
||||
# See https://github.com/pytorch/pytorch/issues/80247
|
||||
DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
|
||||
device_type='cuda', dtypes=[torch.float64]),
|
||||
),
|
||||
decorators=(
|
||||
DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
|
||||
@ -1108,6 +1106,9 @@ module_db: List[ModuleInfo] = [
|
||||
# Failure on ROCM for float32 issue #70125
|
||||
DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
|
||||
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
|
||||
# This was wrongly being skipped before and needs investigation.
|
||||
# See https://github.com/pytorch/pytorch/issues/80247
|
||||
DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
|
||||
),
|
||||
decorators=(
|
||||
DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
|
||||
@ -1136,6 +1137,11 @@ module_db: List[ModuleInfo] = [
|
||||
# Failure on ROCM for float32 issue #70125
|
||||
DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
|
||||
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
|
||||
# This was wrongly being skipped before and needs investigation.
|
||||
# See https://github.com/pytorch/pytorch/issues/80247
|
||||
DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cpu'),
|
||||
DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
|
||||
dtypes=[torch.float64]),
|
||||
),
|
||||
decorators=(
|
||||
DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
|
||||
@ -1150,6 +1156,9 @@ module_db: List[ModuleInfo] = [
|
||||
# Failure on ROCM for float32 issue #70125
|
||||
DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
|
||||
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
|
||||
# This was wrongly being skipped before and needs investigation.
|
||||
# See https://github.com/pytorch/pytorch/issues/80247
|
||||
DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
|
||||
),
|
||||
decorators=(
|
||||
DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
|
||||
@ -1196,6 +1205,10 @@ module_db: List[ModuleInfo] = [
|
||||
# Lazy modules don't currently play well with ModuleInfo tests on the meta device.
|
||||
# See https://github.com/pytorch/pytorch/issues/70505 for more info.
|
||||
DecorateInfo(skipMeta),
|
||||
# This was wrongly being skipped before and needs investigation.
|
||||
# See https://github.com/pytorch/pytorch/issues/80247
|
||||
DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
|
||||
device_type='cuda', dtypes=[torch.float64]),
|
||||
),
|
||||
decorators=(
|
||||
DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
|
||||
@ -1213,6 +1226,9 @@ module_db: List[ModuleInfo] = [
|
||||
# See https://github.com/pytorch/pytorch/issues/70505 for more info.
|
||||
DecorateInfo(skipMeta),
|
||||
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
|
||||
# This was wrongly being skipped before and needs investigation.
|
||||
# See https://github.com/pytorch/pytorch/issues/80247
|
||||
DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
|
||||
),
|
||||
decorators=(
|
||||
DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
|
||||
@ -1247,6 +1263,11 @@ module_db: List[ModuleInfo] = [
|
||||
# See https://github.com/pytorch/pytorch/issues/70505 for more info.
|
||||
DecorateInfo(skipMeta),
|
||||
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
|
||||
# This was wrongly being skipped before and needs investigation.
|
||||
# See https://github.com/pytorch/pytorch/issues/80247
|
||||
DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cpu'),
|
||||
DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
|
||||
dtypes=[torch.float64]),
|
||||
),
|
||||
decorators=(
|
||||
DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
|
||||
@ -1264,6 +1285,9 @@ module_db: List[ModuleInfo] = [
|
||||
# See https://github.com/pytorch/pytorch/issues/70505 for more info.
|
||||
DecorateInfo(skipMeta),
|
||||
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
|
||||
# This was wrongly being skipped before and needs investigation.
|
||||
# See https://github.com/pytorch/pytorch/issues/80247
|
||||
DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
|
||||
),
|
||||
decorators=(
|
||||
DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
|
||||
|
Reference in New Issue
Block a user