Adjust TF32 tests (#44240)

Summary:
- The thresholds of some tests are bumped up. Depending on the random generator, these tests sometimes fail with errors like `0.0059 is not smaller than 0.005`. I ran `test_nn.py` and `test_torch.py` 10+ times each to confirm they are no longer flaky (see the sketch after this list for the TF32 drift behind these failures).
- Add `tf32_on_and_off` to the new `matrix_exp` tests.
- Disable TF32 on test suites other than `test_nn.py` and `test_torch.py`.
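For context, a minimal sketch of the precision gap behind the threshold bumps, assuming an Ampere or later GPU; `torch.backends.cuda.matmul.allow_tf32` is the public switch for TF32 matmuls that the test helpers toggle:

```python
import torch

# On Ampere (or later) GPUs, float32 matmuls use TF32 by default, so
# results can drift past a tight tolerance such as 5e-3 depending on
# the random inputs -- hence the bumped thresholds.
a = torch.randn(1024, 1024, device='cuda')
b = torch.randn(1024, 1024, device='cuda')

torch.backends.cuda.matmul.allow_tf32 = True   # default on Ampere+
tf32 = a @ b

torch.backends.cuda.matmul.allow_tf32 = False  # force full FP32
fp32 = a @ b

print((tf32 - fp32).abs().max())  # typically a small but nonzero gap
```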

cc: ptrblck

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44240

Reviewed By: mruberry

Differential Revision: D23882498

Pulled By: ngimel

fbshipit-source-id: 44a9ec08802c93a2efaf4e01d7487222478b6df8
Author: Gao, Xiang
Date: 2020-09-24 10:23:46 -07:00
Committed by: Facebook GitHub Bot
Parent: b8eab8cdbd
Commit: 3f5eee666c
11 changed files with 124 additions and 28 deletions


```diff
@@ -10,6 +10,7 @@ from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, _in
     RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward
 from textwrap import dedent
 from itertools import product, permutations
+from torch.testing._internal.common_cuda import with_tf32_off
 from test_jit import backward_graph, all_backward_graphs, get_lstm_inputs, get_milstm_inputs, \
     LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell
@@ -710,6 +711,9 @@ class TestFuser(JitTestCase):
                                 "aten::_grad_sum_to_size"))
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    # By default, on Ampere or later GPUs, LSTM computes float tensors at TF32 precision.
+    # We want float tensors to be computed at full precision in order to use the default precision
+    @with_tf32_off
     def test_lstm_concat_cuda(self):
         inputs = get_lstm_inputs('cuda')
         ge = self.checkTrace(LSTMCellC, inputs)
@@ -740,6 +744,9 @@ class TestFuser(JitTestCase):
     # TODO: Fuser doesn't work at all when inputs require grad. Fix that
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    # By default, on Ampere or later GPUs, LSTM computes float tensors at TF32 precision.
+    # We want float tensors to be computed at full precision in order to use the default precision
+    @with_tf32_off
     def test_lstm_traced_cuda(self):
         inputs = get_lstm_inputs('cuda')
         ge = self.checkTrace(LSTMCellF, inputs)
```
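The `with_tf32_off` helper imported above lives in `torch.testing._internal.common_cuda`; a minimal sketch of the idea (the name `with_tf32_off_sketch` and the exact save/restore logic are illustrative, not the actual implementation):

```python
import functools
import torch

def with_tf32_off_sketch(fn):
    """Run a test with TF32 disabled, restoring the previous setting after."""
    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        # Save the current TF32 setting, force full-precision matmuls,
        # and restore the old setting even if the test raises.
        old = torch.backends.cuda.matmul.allow_tf32
        torch.backends.cuda.matmul.allow_tf32 = False
        try:
            return fn(*args, **kwargs)
        finally:
            torch.backends.cuda.matmul.allow_tf32 = old
    return wrapped
```

`tf32_on_and_off` follows the same pattern but runs the test in both modes, relaxing the comparison tolerance for the TF32 pass.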