switching to a simple/full executor

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/29230

Differential Revision: D18402229

Pulled By: Krovatkin

fbshipit-source-id: 62f4bc9bc89c0c7369359bba1359c22a2fa80f46
Author: Nikolay Korovaiko, 2019-11-11 13:39:03 -08:00
Committed by: Facebook Github Bot
parent cedca377bd
commit 5b702ab52b
19 changed files with 412 additions and 262 deletions
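
The test-gating pattern this commit moves to is visible throughout the diff below: the single IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR boolean is replaced by a GRAPH_EXECUTOR value compared against a ProfilingMode enum. A minimal standalone sketch of that pattern follows; it is an illustration, not code from this PR. LEGACY and PROFILING appear in the diff, while the SIMPLE member and the hard-coded GRAPH_EXECUTOR value are assumptions made only to keep the sketch self-contained (in the real suite both come from common_utils).

import unittest
from enum import Enum

class ProfilingMode(Enum):
    # LEGACY and PROFILING appear in the diff; SIMPLE is inferred from the
    # commit title ("simple/full executor") and is an assumption here.
    LEGACY = 1
    SIMPLE = 2
    PROFILING = 3

# In the real test suite this value is imported from common_utils; it is
# hard-coded here only so the sketch runs on its own.
GRAPH_EXECUTOR = ProfilingMode.PROFILING

class ExampleTests(unittest.TestCase):
    # Same shape as the skipIf guards in the diff: run a test only under the
    # legacy executor, skip it under the simple/profiling executors.
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY,
                     "only meaningful on the legacy executor")
    def test_legacy_only_feature(self):
        self.assertTrue(True)

if __name__ == '__main__':
    unittest.main()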


@@ -9,16 +9,16 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.testing import FileCheck
-from common_utils import run_tests, IS_SANDCASTLE
+from common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \
+    enable_profiling_mode
 from textwrap import dedent
 from itertools import product, permutations
 from test_jit import JitTestCase, enable_cpu_fuser, RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, \
     backward_graph, all_backward_graphs, get_lstm_inputs, get_milstm_inputs, \
     LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell, _inline_everything
-from jit_utils import enable_profiling_mode, ProfilingMode, IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR
-if IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR:
+if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
     torch._C._jit_set_profiling_executor(True)
     torch._C._jit_set_profiling_mode(True)
@@ -123,7 +123,7 @@ class TestFuser(JitTestCase):
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(not RUN_CUDA_HALF, "no half support")
-    @unittest.skipIf(IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR, "no half support with profiling on")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
     def test_cuda_half(self):
         x = torch.randn(4, 4, dtype=torch.half, device='cuda')
         y = torch.randn(4, 4, dtype=torch.half, device='cuda')
@@ -303,15 +303,16 @@ class TestFuser(JitTestCase):
         funcs = (func2, funcInf, funcOptMin, funcOptMax)
         for f, inputs in product(funcs, [[a, b], [a, nan]]):
             inp1, inp2 = inputs
-            s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.FULL)
+            s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING)
             self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size', 'aten::_size_if_not_equal'})
             c = s(inp1, inp2)
-            with enable_profiling_mode(ProfilingMode.FULL):
+            with enable_profiling_mode():
                 warmup_backward(c.sum())
             graph = backward_graph(s)
             self.assertAllFused(graph, except_for={'aten::Float'})
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
     def test_dropout(self):
         def func(x):
             x = torch.nn.functional.dropout(x)
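
The hunk above also drops the argument to enable_profiling_mode(); it is now used as a no-argument context manager. A sketch of the save/restore shape such a helper typically has is below. This is an assumption about the helper's body, not the PR's actual implementation, and it presumes that the private torch._C._jit_set_profiling_executor / _jit_set_profiling_mode toggles return the previous setting.

from contextlib import contextmanager
import torch

@contextmanager
def enable_profiling_mode():
    # Sketch only: turn the profiling executor on for the duration of the
    # block, then restore whatever was set before. Assumes the private
    # setters return the prior value.
    old_executor = torch._C._jit_set_profiling_executor(True)
    old_mode = torch._C._jit_set_profiling_mode(True)
    try:
        yield
    finally:
        torch._C._jit_set_profiling_executor(old_executor)
        torch._C._jit_set_profiling_mode(old_mode)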
@@ -461,7 +462,7 @@ class TestFuser(JitTestCase):
         self.assertAllFused(ge.graph_for(x, y))
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @unittest.skipIf(IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR, "broken with profiling on")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
     @_inline_everything
     def test_fuse_decompose_normalization(self):
         class ResLike(torch.jit.ScriptModule):
@@ -552,7 +553,7 @@ class TestFuser(JitTestCase):
                                                        "aten::_size_if_not_equal"))
     @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
-    @unittest.skipIf(IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR, "broken with profiling on")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
     @enable_cpu_fuser
     def test_fuser_deduplication(self):
         # See that fusion kernel outputs are deduplicated when removing _grad_sum_to_size in the fuser's compilation
@@ -905,6 +906,7 @@ class TestFuser(JitTestCase):
         self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'})
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
     def test_grad_sum_to_size_elimination(self):
         def my_broadcasted_cell(a, b, c):
@@ -913,7 +915,7 @@ class TestFuser(JitTestCase):
         s1 = torch.randn(5, 1, requires_grad=True, device='cuda')
         s2 = torch.randn(5, 5, requires_grad=True, device='cuda')
-        module = self.checkScript(my_broadcasted_cell, (s1, s1, s1), profiling=ProfilingMode.FULL)
+        module = self.checkScript(my_broadcasted_cell, (s1, s1, s1), profiling=ProfilingMode.PROFILING)
         forward_graph = module.graph_for(s1, s1, s1)
         self.assertAllFused(forward_graph, except_for=("aten::size", "prim::BroadcastSizes",
                                                        "aten::_size_if_not_equal"))
@@ -925,7 +927,7 @@ class TestFuser(JitTestCase):
             args = s2 if i < 1 else s1, s2 if i < 2 else s1, s2
             args = [a.detach_().requires_grad_() for a in args]
             # recompile, so we don't trigger bailouts
-            module = self.checkScript(my_broadcasted_cell, args, profiling=ProfilingMode.FULL)
+            module = self.checkScript(my_broadcasted_cell, args, profiling=ProfilingMode.PROFILING)
             res = module(s2 if i < 1 else s1, s2 if i < 2 else s1, s2)
             warmup_backward(res.sum(), args)
             grads = torch.autograd.grad(res.sum(), args)
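
Several hunks recompile with checkScript before checking fusion so that bailouts are not triggered: the profiling executor needs a few warm-up runs before graph_for returns the specialized graph. A minimal standalone sketch of that warm-up pattern, using only public torch.jit APIs plus the two toggles shown in the first hunk (and not tied to this test harness), might look like this:

import torch

# Enable the profiling executor, as the first hunk of this diff does when
# GRAPH_EXECUTOR == ProfilingMode.PROFILING.
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)

@torch.jit.script
def add_relu(a, b):
    return torch.relu(a + b)

x = torch.randn(4, 4)
y = torch.randn(4, 4)

# Warm-up runs let the executor collect shape/dtype profiles and specialize.
for _ in range(3):
    add_relu(x, y)

# After warm-up, graph_for returns the graph actually executed for these inputs.
print(add_relu.graph_for(x, y))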