mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Dynamo] Support torch.{cuda/cpu}.amp.autocast (#95416)
For Meta internal use cases. Pull Request resolved: https://github.com/pytorch/pytorch/pull/95416 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
d05f2ae476
commit
7fcf8b1829
@ -7,12 +7,13 @@ from typing import Optional, Tuple
|
||||
import unittest
|
||||
from test_jit import JitTestCase
|
||||
from torch.testing._internal.common_cuda import TEST_CUDA
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
|
||||
from torch.testing import FileCheck
|
||||
from jit.test_models import MnistNet
|
||||
|
||||
TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()
|
||||
|
||||
@skipIfTorchDynamo("Not a TorchDynamo suitable test")
|
||||
class TestAutocast(JitTestCase):
|
||||
def setUp(self):
|
||||
# common input tensors
|
||||
@ -757,6 +758,7 @@ class convbn(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return self.bn(self.conv(x))
|
||||
|
||||
@skipIfTorchDynamo("Not a TorchDynamo suitable test")
|
||||
class TestJitTraceAutocast(JitTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
Reference in New Issue
Block a user