Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[Dynamo] Support torch.{cuda/cpu}.amp.autocast (#95416)
For Meta internal use cases.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95416
Approved by: https://github.com/jansel
Committed by: PyTorch MergeBot
Parent: b8f7bd593c
Commit: c88aa336aa
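
This commit teaches TorchDynamo to trace through torch.cuda.amp.autocast and torch.cpu.amp.autocast regions instead of falling back to eager. A minimal sketch of the resulting user-facing behavior (not code from the diff; the bfloat16 result assumes torch.mm is autocast-eligible on CPU, which it is in this era of PyTorch):

import torch
import torch._dynamo

def fn(a, b):
    # The device-specific constructor is rewritten by Dynamo to the generic
    # torch.amp.autocast with device_type="cpu" (see the TorchVariable hunk below).
    with torch.cpu.amp.autocast(dtype=torch.bfloat16):
        return torch.mm(a, b)

opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
out = opt_fn(torch.rand(8, 8), torch.rand(8, 8))
print(out.dtype)  # torch.bfloat16: the autocast region was traced, not skipped
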
@@ -247,3 +247,4 @@ def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs):
 cond.fallthrough(DispatchKey.PythonTLSSnapshot)
 cond.fallthrough(DispatchKey.ADInplaceOrView)
 cond.fallthrough(DispatchKey.BackendSelect)
+cond.fallthrough(DispatchKey.AutocastCPU)

@@ -133,3 +133,4 @@ def map_functionalize(interpreter, f, xs, *args):
 map.fallthrough(DispatchKey.PythonTLSSnapshot)
 map.fallthrough(DispatchKey.ADInplaceOrView)
 map.fallthrough(DispatchKey.BackendSelect)
+map.fallthrough(DispatchKey.AutocastCPU)

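The AutocastCPU fallthroughs keep the autocast dispatch key from intercepting these higher-order ops before their Python implementations run. A hedged sketch of the situation they address (not code from the diff; the import path for the experimental cond op is an assumption about the functorch API of this era):

import torch
from functorch.experimental.control_flow import cond  # assumed import path

def true_fn(x):
    return torch.mm(x, x)

def false_fn(x):
    return x + 1.0

x = torch.rand(4, 4)
with torch.cpu.amp.autocast(dtype=torch.bfloat16):
    # With the fallthrough registered, cond dispatches past AutocastCPU and the
    # ops inside the selected branch still get normal autocast treatment.
    out = cond(x.sum() > 0, true_fn, false_fn, [x])
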
@@ -26,7 +26,6 @@ test_classes = {}

 ALL_DYNAMIC_XFAILS = {
     "MiscTests": [
-        "test_autocast_sdpa",
         "test_parsing_sdpa",
     ],
     "ReproTests": [

@@ -3287,10 +3287,51 @@ class MiscTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(exported.device.index, 0)
         self.assertEqual(exported.dtype, torch.bfloat16)

+    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
+    def test_cuda_amp_autocast(self):
+        class MyModule(torch.nn.Module):
+            def forward(self, x):
+                a_float32 = torch.rand((8, 8), device="cuda")
+                b_float32 = torch.rand((8, 8), device="cuda")
+
+                with torch.cuda.amp.autocast(dtype=torch.torch.float64):
+                    c_float64 = torch.mm(a_float32, b_float32)
+                return c_float64
+
+        module = MyModule()
+        real = module(torch.tensor([0.5]))
+        real_device = real.device
+        real_dtype = real.dtype
+
+        graph, _ = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
+        exported = graph(torch.tensor([0.5]))
+        self.assertEqual(exported.device, real_device)
+        self.assertEqual(exported.dtype, real_dtype)
+
+        self.assertEqual(exported.device.type, "cuda")
+        self.assertEqual(exported.device.index, 0)
+        self.assertEqual(exported.dtype, torch.float64)
+
+    def test_is_autocast_cpu_enabled(self):
+        def fn(a_float32, b_float32):
+            with torch.cpu.amp.autocast(dtype=torch.bfloat16):
+                c_float16 = torch.mm(a_float32, b_float32)
+                if torch.is_autocast_cpu_enabled():
+                    c_float16 = c_float16 + 1
+            return c_float16
+
+        a = torch.rand((8, 8))
+        b = torch.rand((8, 8))
+        ref = fn(a, b)
+        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
+        res = opt_fn(a, b)
+        self.assertTrue(same(ref, res))
+
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater,
         "Can't run fused SDPA on this platform",
     )
+    @patch.object(torch._dynamo.config, "dynamic_shapes", False)
     def test_autocast_sdpa(self):
         class MyModule(torch.nn.Module):
             def forward(self, query, key, value):

@@ -7,12 +7,13 @@ from typing import Optional, Tuple
 import unittest
 from test_jit import JitTestCase
 from torch.testing._internal.common_cuda import TEST_CUDA
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
 from torch.testing import FileCheck
 from jit.test_models import MnistNet

 TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()

+@skipIfTorchDynamo("Not a TorchDynamo suitable test")
 class TestAutocast(JitTestCase):
     def setUp(self):
         # common input tensors

@@ -757,6 +758,7 @@ class convbn(torch.nn.Module):
     def forward(self, x):
         return self.bn(self.conv(x))

+@skipIfTorchDynamo("Not a TorchDynamo suitable test")
 class TestJitTraceAutocast(JitTestCase):
     def setUp(self):
         super().setUp()

@@ -96,8 +96,6 @@ def _disallowed_function_ids():
         torch.autograd.grad,
         torch.clear_autocast_cache,
         torch.cuda.current_device,
-        torch.cuda.amp.autocast_mode.autocast,
-        torch.cpu.amp.autocast_mode.autocast,
         torch.distributions.constraints.is_dependent,
         torch.distributions.normal.Normal,
         torch.inference_mode,

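With the device-specific autocast constructors dropped from _disallowed_function_ids, entering them no longer forces a graph break. A hedged standalone sketch of how to observe that, using the CompileCounter test backend from torch._dynamo.testing (not code from this diff):

import torch
import torch._dynamo
from torch._dynamo.testing import CompileCounter

def fn(a, b):
    with torch.cpu.amp.autocast(dtype=torch.bfloat16):
        return torch.mm(a, b)

cnt = CompileCounter()
opt_fn = torch._dynamo.optimize(cnt)(fn)
opt_fn(torch.rand(8, 8), torch.rand(8, 8))
print(cnt.frame_count)  # expected 1: no graph break at the autocast entry
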
@@ -280,13 +280,19 @@ class AutocastModeVariable(ContextWrappingVariable):
         self.mode = mode

     def exit(self, tx, *args):
-        self.mode = tx.output.create_node(
-            "call_function", exit_functional_autocast, (self.mode,), {}
+        self.mode = (
+            exit_functional_autocast(self.mode[0]),
+            tx.output.create_node(
+                "call_function", exit_functional_autocast, (self.mode[1],), {}
+            ),
         )

     def enter(self, tx):
-        self.mode = tx.output.create_node(
-            "call_function", enter_functional_autocast, (*self.target_values,), {}
+        self.mode = (
+            enter_functional_autocast(*self.target_values),
+            tx.output.create_node(
+                "call_function", enter_functional_autocast, (*self.target_values,), {}
+            ),
         )

     def module_name(self):

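The change keeps self.mode as a pair: the autocast context is entered eagerly, so the rest of tracing runs with autocast actually enabled, and it is also recorded as a call_function node so the compiled graph reproduces it. A simplified, hedged illustration of that dual bookkeeping (enter_autocast/exit_autocast below are hypothetical stand-ins for enter_functional_autocast/exit_functional_autocast, and a bare torch.fx.Graph stands in for tx.output):

import torch
import torch.fx as fx

def enter_autocast(device_type, dtype):
    # Stand-in: enable autocast eagerly and hand back the context object.
    mode = torch.amp.autocast(device_type=device_type, dtype=dtype)
    mode.__enter__()
    return mode

def exit_autocast(mode):
    # Stand-in: disable the previously entered autocast context.
    mode.__exit__(None, None, None)

graph = fx.Graph()

# enter(): eager side effect plus a recorded graph node, mirroring the pair.
mode = (
    enter_autocast("cpu", torch.bfloat16),
    graph.call_function(enter_autocast, ("cpu", torch.bfloat16)),
)

# ... tracing of the body happens here, with CPU autocast really enabled ...

# exit(): undo the eager state and record the matching node.
mode = (
    exit_autocast(mode[0]),
    graph.call_function(exit_autocast, (mode[1],)),
)
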
@@ -64,6 +64,9 @@ constant_fold_functions = [
     torch.finfo,
     torch.get_default_dtype,
     torch.iinfo,
+    torch.is_autocast_cache_enabled,
+    torch.is_autocast_cpu_enabled,
+    torch.is_autocast_enabled,
     torch.is_floating_point,
     torch.nn.functional._Reduction.get_enum,
 ]

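Because these autocast queries are constant-folded, Dynamo evaluates them at trace time inside the eagerly entered autocast region and specializes the branch, which is exactly what the new test_is_autocast_cpu_enabled above checks. A condensed standalone version of that behavior:

import torch
import torch._dynamo

def fn(a, b):
    with torch.cpu.amp.autocast(dtype=torch.bfloat16):
        c = torch.mm(a, b)
        if torch.is_autocast_cpu_enabled():  # folded to a constant during tracing
            c = c + 1
    return c

opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
res = opt_fn(torch.rand(8, 8), torch.rand(8, 8))
print(res.dtype)  # matches eager: torch.bfloat16, with the +1 branch taken
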
@@ -324,6 +327,13 @@ class TorchVariable(VariableTracker):
             )
         elif self.value is torch.amp.autocast_mode.autocast:
             return AutocastModeVariable.create(target_values=args, kwargs=kwargs)
+        elif self.value in [torch.cuda.amp.autocast, torch.cpu.amp.autocast]:
+            assert "device_type" not in kwargs
+            if self.value is torch.cuda.amp.autocast:
+                kwargs.update({"device_type": ConstantVariable("cuda")})
+            else:
+                kwargs.update({"device_type": ConstantVariable("cpu")})
+            return AutocastModeVariable.create(target_values=args, kwargs=kwargs)
         elif self.value in (
             torch.profiler.profile,
             torch.profiler.record_function,

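The new branch maps the device-specific constructors onto the generic autocast variable by injecting device_type, so inside Dynamo torch.cuda.amp.autocast(...) is handled like torch.amp.autocast(device_type="cuda", ...) and torch.cpu.amp.autocast(...) like torch.amp.autocast(device_type="cpu", ...). The same equivalence holds in eager PyTorch:

import torch

# Equivalent to torch.cuda.amp.autocast(dtype=torch.float16):
with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
    pass

# Equivalent to torch.cpu.amp.autocast(dtype=torch.bfloat16):
with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
    pass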