Revert "[JIT] Support scripting torch.is_autocast_enabled() (#81305)"

This reverts commit bcc9084bc444bb68f038c85d9f2b84de42971b58.

Reverted https://github.com/pytorch/pytorch/pull/81305 on behalf of https://github.com/malfet due to Broke lite-intepreter builds, see https://github.com/pytorch/pytorch/runs/7550084494?check_suite_focus=true
This commit is contained in:
PyTorch MergeBot
2022-07-28 00:02:53 +00:00
parent 0e95746580
commit 554b4060aa
5 changed files with 2 additions and 160 deletions

View File

@ -2,7 +2,7 @@
import torch
from torch.cuda.amp import autocast
from typing import Optional, Tuple
from typing import Optional
import unittest
from test_jit import JitTestCase
@ -819,101 +819,5 @@ class TestJitTraceAutocast(JitTestCase):
continue
test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i])
def test_script_autocast_cpu(self):
def fn(x):
if torch.is_autocast_cpu_enabled():
return x.relu()
else:
return x.sin()
fn_s = torch.jit.script(fn)
x = torch.rand((4, 4)) - 0.5
with torch.cpu.amp.autocast():
self.assertEqual(fn_s(x), fn(x))
with torch.cpu.amp.autocast(enabled=True):
self.assertEqual(fn_s(x), fn(x))
self.assertTrue(any(["is_autocast_cpu_enabled" in x.kind() for x in fn_s.graph.nodes()]))
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_script_autocast_cuda(self):
def fn(x):
if torch.is_autocast_enabled():
return x.relu()
else:
return x.sin()
fn_s = torch.jit.script(fn)
x = torch.rand((4, 4)) - 0.5
with torch.cpu.amp.autocast():
self.assertEqual(fn_s(x), fn(x))
with torch.cuda.amp.autocast(enabled=True):
self.assertEqual(fn_s(x), fn(x))
self.assertTrue(any(["is_autocast_enabled" in x.kind() for x in fn_s.graph.nodes()]))
def test_scripted_aliasing(self):
# torch.is_autocast_enabled should not be able to move inside of the autocast context.
def fn(x):
if torch.is_autocast_enabled():
y = True
else:
y = False
with torch.cuda.amp.autocast(enabled=True):
z = x.relu()
return y, z
fn_s = torch.jit.script(fn)
graph = fn_s.graph
aliasdb = graph.alias_db()
is_enabled_nodes = graph.findAllNodes("aten::is_autocast_enabled")
enter_nodes = graph.findAllNodes("prim::Enter")
self.assertEqual(len(is_enabled_nodes), 1)
self.assertEqual(len(enter_nodes), 1)
self.assertFalse(aliasdb.move_after_topologically_valid(is_enabled_nodes[0], enter_nodes[0]))
def test_script_autocast_enable_and_check(self):
def fn(x, y) -> Tuple[torch.Tensor, bool, torch.Tensor, bool, torch.Tensor, bool]:
b1 = torch.is_autocast_cpu_enabled()
v1 = torch.mm(x, y)
with torch.cpu.amp.autocast(enabled=True):
b2 = torch.is_autocast_cpu_enabled()
v2 = torch.mm(x, y)
with torch.cpu.amp.autocast(enabled=False):
b3 = torch.is_autocast_cpu_enabled()
v3 = torch.mm(x, y)
return (v1, b1, v2, b2, v3, b3)
# bx = is_autocast_cpu_enabled() result should be False iff (vx = mm(x, y)).dtype is float
def check_fn_results(arr):
[v1, b1, v2, b2, v3, b3] = arr
self.assertTrue((v1.dtype == torch.float) != b1)
self.assertTrue((v2.dtype == torch.float) != b2)
self.assertTrue((v3.dtype == torch.float) != b3)
x = torch.rand((2, 2), dtype=torch.float)
y = torch.rand((2, 2), dtype=torch.float)
fn_s = torch.jit.script(fn)
with torch.cpu.amp.autocast(enabled=False):
check_fn_results(fn(x, y))
check_fn_results(fn_s(x, y))
with torch.cpu.amp.autocast(enabled=True):
check_fn_results(fn(x, y))
check_fn_results(fn_s(x, y))
if __name__ == "__main__":
run_tests()