Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
disable tf32 for rnn tests and norm tests (#102005)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102005 Approved by: https://github.com/ngimel
Committed by: PyTorch MergeBot
Parent: ea5eaa8692
Commit: d805a53f1f
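TF32 lowers the precision of float32 matmuls and convolutions on Ampere and newer GPUs, which can push RNN and normalization results outside the tests' default tolerances, so this change wraps the affected tests in the tf32_off() decorator to force full float32 precision. As a rough illustration of what disabling TF32 amounts to (a minimal sketch using the public backend flags, not the internal tf32_off helper itself):

import torch

# Public switches controlling TF32 usage on CUDA (Ampere and newer).
# Turning both off makes float32 matmuls and convolutions run at full precision,
# which is the effect the tf32_off() test decorator has for the wrapped test.
torch.backends.cuda.matmul.allow_tf32 = False  # cuBLAS matmuls
torch.backends.cudnn.allow_tf32 = False        # cuDNN convolutions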
@@ -8,6 +8,7 @@ from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 from torch.testing import make_tensor
+from torch.testing._internal.common_cuda import tf32_off
 from torch.testing._internal.common_utils import (
     is_iterable_of_tensors,
     TestCase,
@@ -433,6 +434,7 @@ class TestDecomp(TestCase):

     @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
     @suppress_warnings
+    @tf32_off()
     # only tests RNNs since we have py dispatcher decomps for them
     @modules(filter(lambda m: m.module_cls in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU), module_db))
     def test_rnn_decomp_module(self, device, dtype, module_info, training):
@@ -9,7 +9,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss
 from torch.nn.utils._per_sample_grad import call_for_per_sample_grads
-from torch.testing._internal.common_cuda import TEST_CUDA
+from torch.testing._internal.common_cuda import TEST_CUDA, tf32_off
 from torch.testing._internal.common_device_type import OpDTypes, instantiate_device_type_tests, ops
 from torch.testing._internal.common_modules import module_db, modules
 from torch.testing._internal.common_nn import TestBase, module_tests, new_module_tests
@@ -305,6 +305,7 @@ class TestExpandedWeightFunctional(TestCase):
         atol, rtol = self._compute_tolerances(device)
         return self._test_conv_model(convnet, 28, 2, device, atol=atol, rtol=rtol)

+    @tf32_off()
     def test_cnn_model_mean(self, device):
         def convnet(num_classes, num_dim):
             return nn.Sequential(
@@ -327,6 +328,7 @@ class TestExpandedWeightFunctional(TestCase):
         return self._test_conv_model(convnet, 28, 2, device, loss_reduction="mean", atol=atol, rtol=rtol)

     @parametrize('num_dim', [1, 2, 3])
+    @tf32_off()
     def test_instance_norm_model(self, num_dim, device):
         def instance_norm_model(num_classes, num_dim):
             conv_layer = nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
@@ -341,6 +343,7 @@ class TestExpandedWeightFunctional(TestCase):
         return self._test_conv_model(instance_norm_model, 7, num_dim, device, atol=atol, rtol=rtol)

     @parametrize('num_dim', [1, 2, 3])
+    @tf32_off()
     def test_group_norm_model(self, num_dim, device):
         def group_norm_model(num_classes, num_dim):
             conv_layer = nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
@@ -354,6 +357,7 @@ class TestExpandedWeightFunctional(TestCase):
         return self._test_conv_model(group_norm_model, 7, num_dim, device, atol=atol, rtol=rtol)

     @parametrize('num_dim', [1, 2, 3])
+    @tf32_off()
     def test_layer_norm_model(self, num_dim, device):
         def layer_norm_model(num_classes, num_dim):
             conv_layer = nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
@@ -513,6 +517,7 @@ class TestExpandedWeightModule(TestCase):
         [self.assertEqual(actual, expected, atol=atol, rtol=rtol) for (actual, expected) in zip(actual_grads, expected_grads)]

     @modules(filter(lambda m_info: m_info.module_cls in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU), module_db))
+    @tf32_off()
     def test_module(self, device, dtype, module_info, training):
         class RNNWrapper(torch.nn.Module):
             def __init__(self, m_cons, args, kwargs):