Mirror of https://github.com/pytorch/pytorch.git
switching to a simple/full executor
Summary: Replaces the boolean IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR flag with a GRAPH_EXECUTOR value compared against the ProfilingMode enum, and renames ProfilingMode.FULL to ProfilingMode.PROFILING throughout the fuser tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/29230

Differential Revision: D18402229

Pulled By: Krovatkin

fbshipit-source-id: 62f4bc9bc89c0c7369359bba1359c22a2fa80f46
Committed by Facebook GitHub Bot
Parent: cedca377bd
Commit: 5b702ab52b
@@ -9,16 +9,16 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.testing import FileCheck
 
-from common_utils import run_tests, IS_SANDCASTLE
+from common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \
+    enable_profiling_mode
 from textwrap import dedent
 from itertools import product, permutations
 
 from test_jit import JitTestCase, enable_cpu_fuser, RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, \
     backward_graph, all_backward_graphs, get_lstm_inputs, get_milstm_inputs, \
     LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell, _inline_everything
-from jit_utils import enable_profiling_mode, ProfilingMode, IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR
 
-if IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR:
+if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
     torch._C._jit_set_profiling_executor(True)
     torch._C._jit_set_profiling_mode(True)
 
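The hunk above replaces the boolean IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR with a GRAPH_EXECUTOR value compared against a ProfilingMode enum. As a rough mental model, a minimal sketch of the common_utils side follows; LEGACY and PROFILING appear in this diff, while the SIMPLE member and the flag handling are assumptions based on the commit title:

    # Hypothetical sketch, not the actual common_utils implementation.
    from enum import Enum
    import torch

    class ProfilingMode(Enum):
        LEGACY = 1     # old graph executor (appears in this diff)
        SIMPLE = 2     # simple executor, no profiling (assumed from the title)
        PROFILING = 3  # full profiling executor (appears in this diff)

    # In the real harness this would be chosen from a test-runner flag.
    GRAPH_EXECUTOR = ProfilingMode.PROFILING

    if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
        # These two torch._C bindings are used verbatim in the hunk above.
        torch._C._jit_set_profiling_executor(True)
        torch._C._jit_set_profiling_mode(True)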
@@ -123,7 +123,7 @@ class TestFuser(JitTestCase):
 
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(not RUN_CUDA_HALF, "no half support")
-    @unittest.skipIf(IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR, "no half support with profiling on")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
     def test_cuda_half(self):
         x = torch.randn(4, 4, dtype=torch.half, device='cuda')
         y = torch.randn(4, 4, dtype=torch.half, device='cuda')
@@ -303,15 +303,16 @@ class TestFuser(JitTestCase):
         funcs = (func2, funcInf, funcOptMin, funcOptMax)
         for f, inputs in product(funcs, [[a, b], [a, nan]]):
             inp1, inp2 = inputs
-            s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.FULL)
+            s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING)
             self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size', 'aten::_size_if_not_equal'})
             c = s(inp1, inp2)
-            with enable_profiling_mode(ProfilingMode.FULL):
+            with enable_profiling_mode():
                 warmup_backward(c.sum())
             graph = backward_graph(s)
             self.assertAllFused(graph, except_for={'aten::Float'})
 
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
     def test_dropout(self):
         def func(x):
             x = torch.nn.functional.dropout(x)
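This hunk drops the argument from enable_profiling_mode, so it now reads as a no-argument context manager. A minimal sketch of what such a context manager could look like, assuming the setters return the previous value; only the usage and the two torch._C toggles are taken from the diff:

    import contextlib
    import torch

    @contextlib.contextmanager
    def enable_profiling_mode():
        # Turn the profiling executor on for the duration of the block,
        # then restore whatever was set before. Capturing the old values
        # from the setters' return is an assumption.
        old_executor = torch._C._jit_set_profiling_executor(True)
        old_mode = torch._C._jit_set_profiling_mode(True)
        try:
            yield
        finally:
            torch._C._jit_set_profiling_executor(old_executor)
            torch._C._jit_set_profiling_mode(old_mode)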
@@ -461,7 +462,7 @@ class TestFuser(JitTestCase):
         self.assertAllFused(ge.graph_for(x, y))
 
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @unittest.skipIf(IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR, "broken with profiling on")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
     @_inline_everything
     def test_fuse_decompose_normalization(self):
         class ResLike(torch.jit.ScriptModule):
@@ -552,7 +553,7 @@ class TestFuser(JitTestCase):
                                                        "aten::_size_if_not_equal"))
 
     @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
-    @unittest.skipIf(IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR, "broken with profiling on")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
     @enable_cpu_fuser
     def test_fuser_deduplication(self):
         # See that fusion kernel outputs are deduplicated when removing _grad_sum_to_size in the fuser's compilation
@@ -905,6 +906,7 @@ class TestFuser(JitTestCase):
         self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'})
 
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
     def test_grad_sum_to_size_elimination(self):
 
         def my_broadcasted_cell(a, b, c):
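The body of this test, shown in the next two hunks, drives the backward pass through warmup_backward before inspecting graphs. A hedged sketch of what such a helper might do under the profiling executor, where a graph is specialized only after a few profiled runs; the name and call shapes warmup_backward(loss) / warmup_backward(loss, inputs) are from the diff, while the run count is an assumption:

    import torch

    PROFILING_RUNS = 2  # assumed: enough runs for the profiler to specialize

    def warmup_backward(loss, inputs=None):
        # Run backward repeatedly so the profiling executor records shape
        # information and swaps in an optimized backward graph.
        for _ in range(PROFILING_RUNS):
            if inputs is not None:
                torch.autograd.grad(loss, inputs, retain_graph=True)
            else:
                loss.backward(retain_graph=True)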
@@ -913,7 +915,7 @@ class TestFuser(JitTestCase):
         s1 = torch.randn(5, 1, requires_grad=True, device='cuda')
         s2 = torch.randn(5, 5, requires_grad=True, device='cuda')
 
-        module = self.checkScript(my_broadcasted_cell, (s1, s1, s1), profiling=ProfilingMode.FULL)
+        module = self.checkScript(my_broadcasted_cell, (s1, s1, s1), profiling=ProfilingMode.PROFILING)
         forward_graph = module.graph_for(s1, s1, s1)
         self.assertAllFused(forward_graph, except_for=("aten::size", "prim::BroadcastSizes",
                                                        "aten::_size_if_not_equal"))
@@ -925,7 +927,7 @@ class TestFuser(JitTestCase):
             args = s2 if i < 1 else s1, s2 if i < 2 else s1, s2
             args = [a.detach_().requires_grad_() for a in args]
             # recompile, so we don't trigger bailouts
-            module = self.checkScript(my_broadcasted_cell, args, profiling=ProfilingMode.FULL)
+            module = self.checkScript(my_broadcasted_cell, args, profiling=ProfilingMode.PROFILING)
             res = module(s2 if i < 1 else s1, s2 if i < 2 else s1, s2)
             warmup_backward(res.sum(), args)
             grads = torch.autograd.grad(res.sum(), args)
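The comment "# recompile, so we don't trigger bailouts" points at a property of the profiling executor: a graph specialized for one set of input shapes guards on those shapes, and a run with different shapes falls back to an unspecialized bailout path instead of the fused kernel. Re-running checkScript gives a fresh module per shape combination, so each assertAllFused check sees a graph specialized for exactly the shapes under test. A small self-contained illustration of the idea (the scripted function is made up for the example):

    import torch

    @torch.jit.script
    def f(x, y):
        return x * y + y

    a = torch.randn(5, 5)
    f(a, a)
    f(a, a)              # warm-up runs: profiled and specialized for 5x5
    b = torch.randn(5, 1)
    f(b, a)              # new shape: the guard fails and a bailout path runs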