Revert "[JIT] Support scripting torch.is_autocast_enabled() (#81305)"
This reverts commit bcc9084bc444bb68f038c85d9f2b84de42971b58. Reverted https://github.com/pytorch/pytorch/pull/81305 on behalf of https://github.com/malfet because it broke lite-interpreter builds; see https://github.com/pytorch/pytorch/runs/7550084494?check_suite_focus=true
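For context, #81305 made the autocast state queries visible to TorchScript, so a scripted function could branch on them. A minimal sketch of what the reverted change enabled (hypothetical usage mirroring the deleted tests below; with this revert applied, scripting such a function is expected to fail again at compile time as an unsupported builtin):

import torch

def fn(x):
    # With #81305 applied, this call is scriptable and is lowered to an
    # aten::is_autocast_enabled node in the TorchScript graph.
    if torch.is_autocast_enabled():
        return x.relu()
    else:
        return x.sin()

fn_s = torch.jit.script(fn)  # worked with #81305; raises after this revert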
@@ -2,7 +2,7 @@

 import torch
 from torch.cuda.amp import autocast
-from typing import Optional, Tuple
+from typing import Optional

 import unittest
 from test_jit import JitTestCase
@@ -819,101 +819,5 @@ class TestJitTraceAutocast(JitTestCase):
                 continue
             test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i])

-    def test_script_autocast_cpu(self):
-        def fn(x):
-            if torch.is_autocast_cpu_enabled():
-                return x.relu()
-            else:
-                return x.sin()
-
-        fn_s = torch.jit.script(fn)
-
-        x = torch.rand((4, 4)) - 0.5
-        with torch.cpu.amp.autocast():
-            self.assertEqual(fn_s(x), fn(x))
-
-        with torch.cpu.amp.autocast(enabled=True):
-            self.assertEqual(fn_s(x), fn(x))
-
-        self.assertTrue(any(["is_autocast_cpu_enabled" in x.kind() for x in fn_s.graph.nodes()]))
-
-    @unittest.skipIf(not TEST_CUDA, "No cuda")
-    def test_script_autocast_cuda(self):
-        def fn(x):
-            if torch.is_autocast_enabled():
-                return x.relu()
-            else:
-                return x.sin()
-
-        fn_s = torch.jit.script(fn)
-
-        x = torch.rand((4, 4)) - 0.5
-        with torch.cpu.amp.autocast():
-            self.assertEqual(fn_s(x), fn(x))
-
-        with torch.cuda.amp.autocast(enabled=True):
-            self.assertEqual(fn_s(x), fn(x))
-
-        self.assertTrue(any(["is_autocast_enabled" in x.kind() for x in fn_s.graph.nodes()]))
-
-
-    def test_scripted_aliasing(self):
-        # torch.is_autocast_enabled should not be able to move inside of the autocast context.
-        def fn(x):
-            if torch.is_autocast_enabled():
-                y = True
-            else:
-                y = False
-            with torch.cuda.amp.autocast(enabled=True):
-                z = x.relu()
-            return y, z
-
-        fn_s = torch.jit.script(fn)
-        graph = fn_s.graph
-
-        aliasdb = graph.alias_db()
-
-        is_enabled_nodes = graph.findAllNodes("aten::is_autocast_enabled")
-        enter_nodes = graph.findAllNodes("prim::Enter")
-
-        self.assertEqual(len(is_enabled_nodes), 1)
-        self.assertEqual(len(enter_nodes), 1)
-
-        self.assertFalse(aliasdb.move_after_topologically_valid(is_enabled_nodes[0], enter_nodes[0]))
-
-
-    def test_script_autocast_enable_and_check(self):
-        def fn(x, y) -> Tuple[torch.Tensor, bool, torch.Tensor, bool, torch.Tensor, bool]:
-            b1 = torch.is_autocast_cpu_enabled()
-            v1 = torch.mm(x, y)
-            with torch.cpu.amp.autocast(enabled=True):
-                b2 = torch.is_autocast_cpu_enabled()
-                v2 = torch.mm(x, y)
-                with torch.cpu.amp.autocast(enabled=False):
-                    b3 = torch.is_autocast_cpu_enabled()
-                    v3 = torch.mm(x, y)
-            return (v1, b1, v2, b2, v3, b3)
-
-        # bx = is_autocast_cpu_enabled() result should be False iff (vx = mm(x, y)).dtype is float
-        def check_fn_results(arr):
-            [v1, b1, v2, b2, v3, b3] = arr
-            self.assertTrue((v1.dtype == torch.float) != b1)
-            self.assertTrue((v2.dtype == torch.float) != b2)
-            self.assertTrue((v3.dtype == torch.float) != b3)
-
-        x = torch.rand((2, 2), dtype=torch.float)
-        y = torch.rand((2, 2), dtype=torch.float)
-
-        fn_s = torch.jit.script(fn)
-
-        with torch.cpu.amp.autocast(enabled=False):
-            check_fn_results(fn(x, y))
-            check_fn_results(fn_s(x, y))
-
-        with torch.cpu.amp.autocast(enabled=True):
-            check_fn_results(fn(x, y))
-            check_fn_results(fn_s(x, y))
-
-
 if __name__ == "__main__":
     run_tests()
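Note that only the TorchScript scripting support is reverted; the eager-mode autocast queries that the deleted test_script_autocast_enable_and_check exercised are untouched. A minimal eager-mode sketch of the invariant its check_fn_results helper asserts (assuming CPU autocast's default bfloat16 dtype, under which torch.mm runs in lower precision):

import torch

x = torch.rand((2, 2), dtype=torch.float)
y = torch.rand((2, 2), dtype=torch.float)

# Outside an autocast region: the query is False and mm stays float32.
assert not torch.is_autocast_cpu_enabled()
assert torch.mm(x, y).dtype == torch.float

with torch.cpu.amp.autocast(enabled=True):
    # Inside the region: the query is True and mm no longer returns
    # float32, so exactly one of the two conditions holds -- the XOR
    # that check_fn_results encodes with `!=`.
    assert torch.is_autocast_cpu_enabled()
    assert torch.mm(x, y).dtype == torch.bfloat16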