[Dynamo] Support torch.{cuda/cpu}.amp.autocast (#95416)

For Meta internal use cases.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95416
Approved by: https://github.com/jansel
Author: Yanbo Liang
Date: 2023-03-08 01:40:27 +00:00
Committed by: PyTorch MergeBot
Parent: b8f7bd593c
Commit: c88aa336aa
8 changed files with 66 additions and 8 deletions
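
In short, this change lets TorchDynamo trace through the device-specific autocast context managers instead of leaving them on the disallowed-function list (see the allowed_functions change below). A minimal sketch of the user-visible effect, mirroring test_cuda_amp_autocast added below (requires a CUDA build; API names follow the PR-era torch._dynamo surface):

import torch
import torch._dynamo

class MyModule(torch.nn.Module):
    def forward(self, x):
        a = torch.rand((8, 8), device="cuda")
        b = torch.rand((8, 8), device="cuda")
        # Dynamo now rewrites this into torch.amp.autocast(device_type="cuda", ...)
        # and captures it in the exported graph rather than graph-breaking.
        with torch.cuda.amp.autocast(dtype=torch.float64):
            return torch.mm(a, b)

graph, _ = torch._dynamo.export(MyModule(), torch.tensor([0.5]))
out = graph(torch.tensor([0.5]))
print(out.device.type, out.dtype)  # expected: cuda torch.float64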


@@ -247,3 +247,4 @@ def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs):
 cond.fallthrough(DispatchKey.PythonTLSSnapshot)
 cond.fallthrough(DispatchKey.ADInplaceOrView)
 cond.fallthrough(DispatchKey.BackendSelect)
+cond.fallthrough(DispatchKey.AutocastCPU)


@@ -133,3 +133,4 @@ def map_functionalize(interpreter, f, xs, *args):
 map.fallthrough(DispatchKey.PythonTLSSnapshot)
 map.fallthrough(DispatchKey.ADInplaceOrView)
 map.fallthrough(DispatchKey.BackendSelect)
+map.fallthrough(DispatchKey.AutocastCPU)


@@ -26,7 +26,6 @@ test_classes = {}
 ALL_DYNAMIC_XFAILS = {
     "MiscTests": [
-        "test_autocast_sdpa",
         "test_parsing_sdpa",
     ],
     "ReproTests": [


@@ -3287,10 +3287,51 @@ class MiscTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(exported.device.index, 0)
         self.assertEqual(exported.dtype, torch.bfloat16)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
+    def test_cuda_amp_autocast(self):
+        class MyModule(torch.nn.Module):
+            def forward(self, x):
+                a_float32 = torch.rand((8, 8), device="cuda")
+                b_float32 = torch.rand((8, 8), device="cuda")
+
+                with torch.cuda.amp.autocast(dtype=torch.torch.float64):
+                    c_float64 = torch.mm(a_float32, b_float32)
+                return c_float64
+
+        module = MyModule()
+        real = module(torch.tensor([0.5]))
+        real_device = real.device
+        real_dtype = real.dtype
+
+        graph, _ = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
+        exported = graph(torch.tensor([0.5]))
+        self.assertEqual(exported.device, real_device)
+        self.assertEqual(exported.dtype, real_dtype)
+
+        self.assertEqual(exported.device.type, "cuda")
+        self.assertEqual(exported.device.index, 0)
+        self.assertEqual(exported.dtype, torch.float64)
+
+    def test_is_autocast_cpu_enabled(self):
+        def fn(a_float32, b_float32):
+            with torch.cpu.amp.autocast(dtype=torch.bfloat16):
+                c_float16 = torch.mm(a_float32, b_float32)
+                if torch.is_autocast_cpu_enabled():
+                    c_float16 = c_float16 + 1
+            return c_float16
+
+        a = torch.rand((8, 8))
+        b = torch.rand((8, 8))
+        ref = fn(a, b)
+        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
+        res = opt_fn(a, b)
+        self.assertTrue(same(ref, res))
+
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater,
         "Can't run fused SDPA on this platform",
     )
+    @patch.object(torch._dynamo.config, "dynamic_shapes", False)
     def test_autocast_sdpa(self):
         class MyModule(torch.nn.Module):
             def forward(self, query, key, value):


@@ -7,12 +7,13 @@ from typing import Optional, Tuple
 import unittest
 from test_jit import JitTestCase
 from torch.testing._internal.common_cuda import TEST_CUDA
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
 from torch.testing import FileCheck
 from jit.test_models import MnistNet
 
 TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()
 
+@skipIfTorchDynamo("Not a TorchDynamo suitable test")
 class TestAutocast(JitTestCase):
     def setUp(self):
         # common input tensors

@@ -757,6 +758,7 @@ class convbn(torch.nn.Module):
     def forward(self, x):
         return self.bn(self.conv(x))
 
+@skipIfTorchDynamo("Not a TorchDynamo suitable test")
 class TestJitTraceAutocast(JitTestCase):
     def setUp(self):
         super().setUp()


@@ -96,8 +96,6 @@ def _disallowed_function_ids():
         torch.autograd.grad,
         torch.clear_autocast_cache,
         torch.cuda.current_device,
-        torch.cuda.amp.autocast_mode.autocast,
-        torch.cpu.amp.autocast_mode.autocast,
         torch.distributions.constraints.is_dependent,
         torch.distributions.normal.Normal,
         torch.inference_mode,


@@ -280,13 +280,19 @@ class AutocastModeVariable(ContextWrappingVariable):
         self.mode = mode
 
     def exit(self, tx, *args):
-        self.mode = tx.output.create_node(
-            "call_function", exit_functional_autocast, (self.mode,), {}
+        self.mode = (
+            exit_functional_autocast(self.mode[0]),
+            tx.output.create_node(
+                "call_function", exit_functional_autocast, (self.mode[1],), {}
+            ),
         )
 
     def enter(self, tx):
-        self.mode = tx.output.create_node(
-            "call_function", enter_functional_autocast, (*self.target_values,), {}
+        self.mode = (
+            enter_functional_autocast(*self.target_values),
+            tx.output.create_node(
+                "call_function", enter_functional_autocast, (*self.target_values,), {}
+            ),
         )
 
     def module_name(self):

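The enter/exit change above makes AutocastModeVariable keep two pieces of state in self.mode: the result of actually entering (or exiting) autocast eagerly while Dynamo traces, so trace-time dispatch and constant-folded queries observe the real autocast state, plus the FX node recorded so the compiled graph re-enters autocast at run time. A rough sketch of the observable effect on CPU, patterned after test_is_autocast_cpu_enabled above (the bfloat16 output assumes torch.mm is autocast-eligible on this build; constant folding of the query relies on the constant_fold_functions change below):

import torch
import torch._dynamo

def fn(a, b):
    with torch.cpu.amp.autocast(dtype=torch.bfloat16):
        out = torch.mm(a, b)                 # traced (and run) under CPU autocast -> bfloat16
        if torch.is_autocast_cpu_enabled():  # folded to a trace-time constant (True here)
            out = out + 1
    return out

a, b = torch.rand(8, 8), torch.rand(8, 8)
ref = fn(a, b)
res = torch._dynamo.optimize("eager", nopython=True)(fn)(a, b)
assert torch.equal(ref, res)  # expected: eager and compiled results match
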

@@ -64,6 +64,9 @@ constant_fold_functions = [
     torch.finfo,
     torch.get_default_dtype,
     torch.iinfo,
+    torch.is_autocast_cache_enabled,
+    torch.is_autocast_cpu_enabled,
+    torch.is_autocast_enabled,
     torch.is_floating_point,
     torch.nn.functional._Reduction.get_enum,
 ]
@@ -324,6 +327,13 @@ class TorchVariable(VariableTracker):
             )
         elif self.value is torch.amp.autocast_mode.autocast:
             return AutocastModeVariable.create(target_values=args, kwargs=kwargs)
+        elif self.value in [torch.cuda.amp.autocast, torch.cpu.amp.autocast]:
+            assert "device_type" not in kwargs
+            if self.value is torch.cuda.amp.autocast:
+                kwargs.update({"device_type": ConstantVariable("cuda")})
+            else:
+                kwargs.update({"device_type": ConstantVariable("cpu")})
+            return AutocastModeVariable.create(target_values=args, kwargs=kwargs)
         elif self.value in (
             torch.profiler.profile,
             torch.profiler.record_function,
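
The new elif branch above only normalizes the device-specific wrappers onto the existing generic handling: device_type is injected as a constant and AutocastModeVariable takes over. Roughly, inside a compiled frame the two spellings below are now treated the same way (a sketch, assuming a CPU-only setup):

import torch
import torch._dynamo

def f(x):
    with torch.cpu.amp.autocast(dtype=torch.bfloat16):  # handled via the new branch
        return torch.mm(x, x)

def g(x):
    with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):  # already supported
        return torch.mm(x, x)

x = torch.rand(8, 8)
f_opt = torch._dynamo.optimize("eager", nopython=True)(f)
g_opt = torch._dynamo.optimize("eager", nopython=True)(g)
assert torch.equal(f_opt(x), g_opt(x))  # expected: identical results, no graph breaks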