disable tf32 for rnn tests and norm tests (#102005)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102005
Approved by: https://github.com/ngimel
Author: Fuzzkatt
Date: 2023-05-24 02:22:53 +00:00
Committed by: PyTorch MergeBot
Parent: ea5eaa8692
Commit: d805a53f1f
2 changed files with 8 additions and 1 deletion
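
For context: TF32 is a reduced-precision math mode (10-bit mantissa) that Ampere-and-newer NVIDIA GPUs can use for float32 matmuls and convolutions, which can push results outside the tight tolerances these tests assert. The tf32_off() decorator imported in the diff below comes from torch.testing._internal.common_cuda. As a minimal stand-alone sketch (not the actual implementation), an equivalent helper only needs to toggle the two public TF32 backend flags for the duration of a test; the name tf32_disabled here is hypothetical:

import contextlib
import torch

@contextlib.contextmanager
def tf32_disabled():
    # Save the current TF32 settings so the test leaves global state untouched.
    old_matmul = torch.backends.cuda.matmul.allow_tf32
    old_cudnn = torch.backends.cudnn.allow_tf32
    try:
        # Force full-precision float32 math in cuBLAS matmuls and in
        # cuDNN convolutions/RNNs while the wrapped block runs.
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
        yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = old_matmul
        torch.backends.cudnn.allow_tf32 = old_cudnn

Scoping the change this way keeps TF32 disabled only for the decorated test, rather than flipping the flags globally for the whole suite.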

@@ -9,7 +9,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss
 from torch.nn.utils._per_sample_grad import call_for_per_sample_grads
-from torch.testing._internal.common_cuda import TEST_CUDA
+from torch.testing._internal.common_cuda import TEST_CUDA, tf32_off
 from torch.testing._internal.common_device_type import OpDTypes, instantiate_device_type_tests, ops
 from torch.testing._internal.common_modules import module_db, modules
 from torch.testing._internal.common_nn import TestBase, module_tests, new_module_tests
@@ -305,6 +305,7 @@ class TestExpandedWeightFunctional(TestCase):
         atol, rtol = self._compute_tolerances(device)
         return self._test_conv_model(convnet, 28, 2, device, atol=atol, rtol=rtol)
 
+    @tf32_off()
     def test_cnn_model_mean(self, device):
         def convnet(num_classes, num_dim):
             return nn.Sequential(
@@ -327,6 +328,7 @@ class TestExpandedWeightFunctional(TestCase):
         return self._test_conv_model(convnet, 28, 2, device, loss_reduction="mean", atol=atol, rtol=rtol)
 
     @parametrize('num_dim', [1, 2, 3])
+    @tf32_off()
     def test_instance_norm_model(self, num_dim, device):
         def instance_norm_model(num_classes, num_dim):
             conv_layer = nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
@@ -341,6 +343,7 @@ class TestExpandedWeightFunctional(TestCase):
         return self._test_conv_model(instance_norm_model, 7, num_dim, device, atol=atol, rtol=rtol)
 
     @parametrize('num_dim', [1, 2, 3])
+    @tf32_off()
     def test_group_norm_model(self, num_dim, device):
         def group_norm_model(num_classes, num_dim):
             conv_layer = nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
@@ -354,6 +357,7 @@ class TestExpandedWeightFunctional(TestCase):
         return self._test_conv_model(group_norm_model, 7, num_dim, device, atol=atol, rtol=rtol)
 
     @parametrize('num_dim', [1, 2, 3])
+    @tf32_off()
     def test_layer_norm_model(self, num_dim, device):
         def layer_norm_model(num_classes, num_dim):
             conv_layer = nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
@@ -513,6 +517,7 @@ class TestExpandedWeightModule(TestCase):
         [self.assertEqual(actual, expected, atol=atol, rtol=rtol) for (actual, expected) in zip(actual_grads, expected_grads)]
 
     @modules(filter(lambda m_info: m_info.module_cls in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU), module_db))
+    @tf32_off()
     def test_module(self, device, dtype, module_info, training):
         class RNNWrapper(torch.nn.Module):
             def __init__(self, m_cons, args, kwargs):
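
A hypothetical stand-alone repro of the kind of mismatch this change guards against: the same float32 matmul can differ between TF32-on and TF32-off by more than a typical test tolerance such as atol=1e-5 (the difference is only nonzero on an Ampere-or-newer GPU, where TF32 actually kicks in):

import torch

if torch.cuda.is_available():
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")

    torch.backends.cuda.matmul.allow_tf32 = True
    tf32_result = a @ b  # may be computed with 10-bit-mantissa TF32 math

    torch.backends.cuda.matmul.allow_tf32 = False
    fp32_result = a @ b  # full-precision float32 math

    # On Ampere+ this maximum difference typically exceeds the tight
    # tolerances the affected tests assert with.
    print((tf32_result - fp32_result).abs().max().item())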