mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[codemod][lint][fbcode/c*] Enable BLACK by default
Test Plan: manual inspection & sandcastle Reviewed By: zertosh Differential Revision: D30279364 fbshipit-source-id: c1ed77dfe43a3bde358f92737cd5535ae5d13c9a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
aac3c7bd06
commit
b004307252
@ -1,22 +1,44 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import unittest
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from itertools import product, permutations
|
||||
from textwrap import dedent
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from test_jit import (
|
||||
backward_graph,
|
||||
all_backward_graphs,
|
||||
get_lstm_inputs,
|
||||
get_milstm_inputs,
|
||||
LSTMCellC,
|
||||
LSTMCellF,
|
||||
LSTMCellS,
|
||||
MiLSTMCell,
|
||||
)
|
||||
from torch.testing import FileCheck
|
||||
|
||||
from torch.testing._internal.common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \
|
||||
enable_profiling_mode_for_profiling_tests, IS_WINDOWS, TemporaryDirectoryName, shell
|
||||
from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, _inline_everything, \
|
||||
RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward
|
||||
from textwrap import dedent
|
||||
from itertools import product, permutations
|
||||
from torch.testing._internal.common_cuda import with_tf32_off
|
||||
|
||||
from test_jit import backward_graph, all_backward_graphs, get_lstm_inputs, get_milstm_inputs, \
|
||||
LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell
|
||||
from torch.testing._internal.common_utils import (
|
||||
run_tests,
|
||||
IS_SANDCASTLE,
|
||||
ProfilingMode,
|
||||
GRAPH_EXECUTOR,
|
||||
enable_profiling_mode_for_profiling_tests,
|
||||
IS_WINDOWS,
|
||||
TemporaryDirectoryName,
|
||||
shell,
|
||||
)
|
||||
from torch.testing._internal.jit_utils import (
|
||||
JitTestCase,
|
||||
enable_cpu_fuser,
|
||||
_inline_everything,
|
||||
RUN_CUDA,
|
||||
RUN_CUDA_HALF,
|
||||
RUN_CUDA_MULTI_GPU,
|
||||
warmup_backward,
|
||||
)
|
||||
|
||||
if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
|
||||
torch._C._jit_set_profiling_executor(True)
|
||||
@ -24,7 +46,7 @@ if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
|
||||
|
||||
|
||||
def strip_profiling_nodes(nodes):
|
||||
profiling_opcodes = set(['prim::BailoutTemplate', 'prim::BailOut'])
|
||||
profiling_opcodes = set(["prim::BailoutTemplate", "prim::BailOut"])
|
||||
return [n for n in nodes if n.kind() not in profiling_opcodes]
|
||||
|
||||
|
||||
@ -39,18 +61,29 @@ def warmup_forward(f, *args):
|
||||
class TestFuser(JitTestCase):
|
||||
def assertAllFused(self, graph, except_for=()):
|
||||
|
||||
diff_graphs = [n for n in graph.nodes() if n.kind() == 'prim::DifferentiableGraph']
|
||||
diff_graphs = [
|
||||
n for n in graph.nodes() if n.kind() == "prim::DifferentiableGraph"
|
||||
]
|
||||
if len(diff_graphs) > 0:
|
||||
self.assertEqual(len(diff_graphs), 1)
|
||||
graph = diff_graphs[0].g('Subgraph')
|
||||
graph = diff_graphs[0].g("Subgraph")
|
||||
|
||||
allowed_nodes = {'prim::Constant', 'prim::FusionGroup', 'prim::BailoutTemplate',
|
||||
'prim::BailOut', 'prim::TupleConstruct'} | set(except_for)
|
||||
self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()),
|
||||
'got {}'.format(graph))
|
||||
self.assertTrue([node.kind() for node in graph.nodes()].count('prim::FusionGroup') == 1)
|
||||
allowed_nodes = {
|
||||
"prim::Constant",
|
||||
"prim::FusionGroup",
|
||||
"prim::BailoutTemplate",
|
||||
"prim::BailOut",
|
||||
"prim::TupleConstruct",
|
||||
} | set(except_for)
|
||||
self.assertTrue(
|
||||
all(node.kind() in allowed_nodes for node in graph.nodes()),
|
||||
"got {}".format(graph),
|
||||
)
|
||||
self.assertTrue(
|
||||
[node.kind() for node in graph.nodes()].count("prim::FusionGroup") == 1
|
||||
)
|
||||
|
||||
def _test_fused_abs(self, device='cpu'):
|
||||
def _test_fused_abs(self, device="cpu"):
|
||||
def func(x):
|
||||
return x.abs() * 2
|
||||
|
||||
@ -67,11 +100,15 @@ class TestFuser(JitTestCase):
|
||||
@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
|
||||
@enable_cpu_fuser
|
||||
def test_abs_cpu_unicode_temp_dir(self):
|
||||
with TemporaryDirectoryName(suffix='中文') as dname:
|
||||
with TemporaryDirectoryName(suffix="中文") as dname:
|
||||
shell_env = os.environ.copy()
|
||||
shell_env['TMP'] = dname
|
||||
cmd = [sys.executable, os.path.basename(__file__), type(self).__name__ + '.test_abs_cpu']
|
||||
legacy_jit_flag = '--jit_executor=legacy'
|
||||
shell_env["TMP"] = dname
|
||||
cmd = [
|
||||
sys.executable,
|
||||
os.path.basename(__file__),
|
||||
type(self).__name__ + ".test_abs_cpu",
|
||||
]
|
||||
legacy_jit_flag = "--jit_executor=legacy"
|
||||
for v in sys.argv:
|
||||
if v == legacy_jit_flag:
|
||||
cmd.append(legacy_jit_flag)
|
||||
@ -103,9 +140,15 @@ class TestFuser(JitTestCase):
|
||||
z1, z2 = (x + y).chunk(2, dim=1)
|
||||
return z1 * z2
|
||||
|
||||
x = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
||||
y = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
||||
traced_f = torch.jit.trace(f, (x, y,))
|
||||
x = torch.randn(4, 4, dtype=torch.float, device="cuda")
|
||||
y = torch.randn(4, 4, dtype=torch.float, device="cuda")
|
||||
traced_f = torch.jit.trace(
|
||||
f,
|
||||
(
|
||||
x,
|
||||
y,
|
||||
),
|
||||
)
|
||||
self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y))
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
@ -114,18 +157,21 @@ class TestFuser(JitTestCase):
|
||||
return x * scale + shift
|
||||
|
||||
inputs = [
|
||||
torch.randn(4, 4, dtype=torch.float, device='cuda'),
|
||||
torch.randn(4, dtype=torch.float, device='cuda'),
|
||||
torch.randn(4, dtype=torch.float, device='cuda'),
|
||||
torch.randn(4, 4, dtype=torch.float, device="cuda"),
|
||||
torch.randn(4, dtype=torch.float, device="cuda"),
|
||||
torch.randn(4, dtype=torch.float, device="cuda"),
|
||||
]
|
||||
ge = self.checkTrace(scaleshift, inputs)
|
||||
self.assertAllFused(ge.graph_for(*inputs))
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no bfloat support with profiling on")
|
||||
@unittest.skipIf(
|
||||
GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no bfloat support with profiling on"
|
||||
)
|
||||
def test_cuda_bfloat16(self):
|
||||
def foo(x, y):
|
||||
return (x + y).relu()
|
||||
|
||||
m = torch.jit.script(foo)
|
||||
x = torch.randn(65536).cuda().bfloat16()
|
||||
y = torch.randn_like(x)
|
||||
@ -133,16 +179,14 @@ class TestFuser(JitTestCase):
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
@unittest.skipIf(not RUN_CUDA_HALF, "no half support")
|
||||
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
|
||||
@unittest.skipIf(
|
||||
GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on"
|
||||
)
|
||||
def test_cuda_half(self):
|
||||
x = torch.randn(4, 4, dtype=torch.half, device='cuda')
|
||||
y = torch.randn(4, 4, dtype=torch.half, device='cuda')
|
||||
x = torch.randn(4, 4, dtype=torch.half, device="cuda")
|
||||
y = torch.randn(4, 4, dtype=torch.half, device="cuda")
|
||||
|
||||
funcs = [
|
||||
self.fn_test_comparison_gt_lt,
|
||||
self.fn_test_relu,
|
||||
self.fn_test_exp
|
||||
]
|
||||
funcs = [self.fn_test_comparison_gt_lt, self.fn_test_relu, self.fn_test_exp]
|
||||
|
||||
# Note: Non fused inputs must be float to prevent loss of precision
|
||||
inputs = (x.float(), y.float())
|
||||
@ -161,9 +205,17 @@ class TestFuser(JitTestCase):
|
||||
# Verifies gradients
|
||||
for output, fusion_output in zip(outputs_half, fusion_outputs):
|
||||
grads = torch.autograd.grad(
|
||||
output.float().sum(), local_inputs, allow_unused=True, retain_graph=True)
|
||||
output.float().sum(),
|
||||
local_inputs,
|
||||
allow_unused=True,
|
||||
retain_graph=True,
|
||||
)
|
||||
fusion_grads = torch.autograd.grad(
|
||||
fusion_output.sum(), local_fusion_inputs, allow_unused=True, retain_graph=True)
|
||||
fusion_output.sum(),
|
||||
local_fusion_inputs,
|
||||
allow_unused=True,
|
||||
retain_graph=True,
|
||||
)
|
||||
grads_half = [t.half() for t in grads]
|
||||
self.assertEqual(grads_half, fusion_grads)
|
||||
|
||||
@ -177,8 +229,8 @@ class TestFuser(JitTestCase):
|
||||
|
||||
# NOTE: y is broadcastable to x, but output of f(x, y) should have
|
||||
# shape 3x4, and not 4x4.
|
||||
x = torch.randn(2, 4, dtype=torch.float, device='cuda')
|
||||
y = torch.randn(1, 4, dtype=torch.float, device='cuda')
|
||||
x = torch.randn(2, 4, dtype=torch.float, device="cuda")
|
||||
y = torch.randn(1, 4, dtype=torch.float, device="cuda")
|
||||
|
||||
scripted = self.checkScript(f, (x, y))
|
||||
self.assertAllFused(scripted.graph_for(x, y))
|
||||
@ -201,7 +253,7 @@ class TestFuser(JitTestCase):
|
||||
a, b, c = x.chunk(3, 1)
|
||||
return a * b + c
|
||||
|
||||
inputs = [torch.randn(10, 6, dtype=torch.float, device='cuda')]
|
||||
inputs = [torch.randn(10, 6, dtype=torch.float, device="cuda")]
|
||||
|
||||
ge = self.checkScript(fn, inputs)
|
||||
graph = ge.graph_for(*inputs)
|
||||
@ -209,7 +261,7 @@ class TestFuser(JitTestCase):
|
||||
FileCheck().check("prim::ConstantChunk[chunks=3, dim=1]").run(str(graph))
|
||||
|
||||
@staticmethod
|
||||
def _test_chunk_correctness(self, device='cpu'):
|
||||
def _test_chunk_correctness(self, device="cpu"):
|
||||
def chunk_4_0(x):
|
||||
x0, x1, x2, x3 = x.chunk(4, 0)
|
||||
return x0 + x1 + x2 + x3
|
||||
@ -226,10 +278,8 @@ class TestFuser(JitTestCase):
|
||||
tensors = [
|
||||
# splitSize = 1
|
||||
torch.randn(4, 4, 4, dtype=torch.float, device=device),
|
||||
|
||||
# contiguous case
|
||||
torch.randn(12, 8, 16, dtype=torch.float, device=device),
|
||||
|
||||
# non-contiguous case
|
||||
torch.randn(12, 8, 16, dtype=torch.float, device=device).transpose(1, 2),
|
||||
]
|
||||
@ -241,11 +291,11 @@ class TestFuser(JitTestCase):
|
||||
@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
|
||||
@enable_cpu_fuser
|
||||
def test_chunk_correctness(self):
|
||||
return self._test_chunk_correctness(self, 'cpu')
|
||||
return self._test_chunk_correctness(self, "cpu")
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "No CUDA")
|
||||
def test_chunk_correctness_cuda(self):
|
||||
return self._test_chunk_correctness(self, 'cuda')
|
||||
return self._test_chunk_correctness(self, "cuda")
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
def test_chunk_distributes_cuda(self):
|
||||
@ -253,13 +303,14 @@ class TestFuser(JitTestCase):
|
||||
z1, z2 = (x + y).chunk(2, dim=1)
|
||||
return z1 * z2
|
||||
|
||||
x = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
||||
y = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
||||
x = torch.randn(4, 4, dtype=torch.float, device="cuda")
|
||||
y = torch.randn(4, 4, dtype=torch.float, device="cuda")
|
||||
|
||||
ge = self.checkTrace(f, (x, y))
|
||||
graph = ge.graph_for(x, y)
|
||||
FileCheck().check("broadcast_tensors").check('with prim::FusionGroup_') \
|
||||
.check_count('ConstantChunk', 2, exactly=True).run(str(graph))
|
||||
FileCheck().check("broadcast_tensors").check(
|
||||
"with prim::FusionGroup_"
|
||||
).check_count("ConstantChunk", 2, exactly=True).run(str(graph))
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
def test_chunk_motion_deduplicates_inputs(self):
|
||||
@ -274,12 +325,12 @@ class TestFuser(JitTestCase):
|
||||
return z0 * z1
|
||||
|
||||
inputs = [
|
||||
torch.tensor([1.1, 1.2], device='cuda', dtype=torch.float),
|
||||
torch.tensor([1.1, 1.2], device="cuda", dtype=torch.float),
|
||||
]
|
||||
for func in [func1, func2]:
|
||||
module = self.checkScript(func, inputs)
|
||||
forward_graph = module.graph_for(*inputs)
|
||||
self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
|
||||
self.assertGraphContainsExactly(forward_graph, "prim::FusionGroup", 1)
|
||||
fusion_group = list(forward_graph.nodes())[-1]
|
||||
self.assertEqual(len(list(fusion_group.inputs())), 1)
|
||||
|
||||
@ -294,10 +345,10 @@ class TestFuser(JitTestCase):
|
||||
return s + x1 + x2 + x3 + y1 + y2 + z1 + z2
|
||||
|
||||
inputs = [
|
||||
torch.randn(5, 2, 3, dtype=torch.float, device='cuda'),
|
||||
torch.randn(5, 6, 3, dtype=torch.float, device='cuda'),
|
||||
torch.randn(10, 2, 3, dtype=torch.float, device='cuda'),
|
||||
torch.randn(5, 2, 6, dtype=torch.float, device='cuda'),
|
||||
torch.randn(5, 2, 3, dtype=torch.float, device="cuda"),
|
||||
torch.randn(5, 6, 3, dtype=torch.float, device="cuda"),
|
||||
torch.randn(10, 2, 3, dtype=torch.float, device="cuda"),
|
||||
torch.randn(5, 2, 6, dtype=torch.float, device="cuda"),
|
||||
]
|
||||
|
||||
ge = self.checkScript(fn, inputs)
|
||||
@ -313,11 +364,9 @@ class TestFuser(JitTestCase):
|
||||
|
||||
a = torch.randn(4, 4, dtype=torch.float, device="cuda")
|
||||
b = torch.randn(4, 4, dtype=torch.float, device="cuda")
|
||||
nan = torch.tensor(float('nan'), dtype=torch.float, device="cuda")
|
||||
nan = torch.tensor(float("nan"), dtype=torch.float, device="cuda")
|
||||
|
||||
for f, inputs in product(
|
||||
(tmax, tmin),
|
||||
([a, b], [a, nan], [b, nan])):
|
||||
for f, inputs in product((tmax, tmin), ([a, b], [a, nan], [b, nan])):
|
||||
s = self.checkScript(f, inputs)
|
||||
self.assertAllFused(s.graph_for(*inputs))
|
||||
|
||||
@ -327,7 +376,7 @@ class TestFuser(JitTestCase):
|
||||
return torch.clamp(a + b, min=0, max=2)
|
||||
|
||||
def funcInf(a, b):
|
||||
return torch.clamp(a + b, min=0, max=float('inf'))
|
||||
return torch.clamp(a + b, min=0, max=float("inf"))
|
||||
|
||||
def funcOptMin(a, b):
|
||||
return torch.clamp(a + b, max=2)
|
||||
@ -335,37 +384,44 @@ class TestFuser(JitTestCase):
|
||||
def funcOptMax(a, b):
|
||||
return torch.clamp(a + b, min=0)
|
||||
|
||||
a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
|
||||
b = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
||||
nan = torch.tensor(float('nan'), dtype=torch.float, device='cuda')
|
||||
a = torch.randn(4, 4, dtype=torch.float, device="cuda", requires_grad=True)
|
||||
b = torch.randn(4, 4, dtype=torch.float, device="cuda")
|
||||
nan = torch.tensor(float("nan"), dtype=torch.float, device="cuda")
|
||||
|
||||
funcs = (func2, funcInf, funcOptMin, funcOptMax)
|
||||
for f, inputs in product(funcs, [[a, b], [a, nan]]):
|
||||
f.__disable_jit_function_caching__ = True
|
||||
inp1, inp2 = inputs
|
||||
s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING)
|
||||
self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size', 'aten::_size_if_not_equal'})
|
||||
self.assertAllFused(
|
||||
s.graph_for(inp1, inp2),
|
||||
except_for={"aten::size", "aten::_size_if_not_equal"},
|
||||
)
|
||||
c = s(inp1, inp2)
|
||||
with enable_profiling_mode_for_profiling_tests():
|
||||
warmup_backward(c.sum())
|
||||
graph = backward_graph(s)
|
||||
self.assertAllFused(graph, except_for={'aten::Float', 'aten::_grad_sum_to_size'})
|
||||
self.assertAllFused(
|
||||
graph, except_for={"aten::Float", "aten::_grad_sum_to_size"}
|
||||
)
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
|
||||
@unittest.skipIf(
|
||||
GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on"
|
||||
)
|
||||
def test_dropout(self):
|
||||
def func(x):
|
||||
x = torch.nn.functional.dropout(x)
|
||||
return torch.nn.functional.relu(x)
|
||||
|
||||
a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
|
||||
a = torch.randn(4, 4, dtype=torch.float, device="cuda", requires_grad=True)
|
||||
s = torch.jit.script(func)
|
||||
c = s(a)
|
||||
c = s(a)
|
||||
warmup_backward(c.sum())
|
||||
# skip_check to skip extra bailout nodes in between
|
||||
graph = backward_graph(s, skip_check=True)
|
||||
self.assertAllFused(graph, except_for={'aten::div', 'prim::Constant'})
|
||||
self.assertAllFused(graph, except_for={"aten::div", "prim::Constant"})
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
def test_comparison_eq_ne(self):
|
||||
@ -376,8 +432,8 @@ class TestFuser(JitTestCase):
|
||||
z = z * mask + y
|
||||
return z
|
||||
|
||||
x = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
||||
y = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
||||
x = torch.randn(4, 4, dtype=torch.float, device="cuda")
|
||||
y = torch.randn(4, 4, dtype=torch.float, device="cuda")
|
||||
|
||||
ge = self.checkTrace(f, (x, y))
|
||||
self.assertAllFused(ge.graph_for(x, y))
|
||||
@ -392,8 +448,8 @@ class TestFuser(JitTestCase):
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
def test_comparison_gt_lt_cuda(self):
|
||||
x = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
||||
y = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
||||
x = torch.randn(4, 4, dtype=torch.float, device="cuda")
|
||||
y = torch.randn(4, 4, dtype=torch.float, device="cuda")
|
||||
|
||||
ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y))
|
||||
self.assertAllFused(ge.graph_for(x, y))
|
||||
@ -407,21 +463,27 @@ class TestFuser(JitTestCase):
|
||||
z = z * mask + y
|
||||
return z
|
||||
|
||||
x = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
||||
y = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
||||
x = torch.randn(4, 4, dtype=torch.float, device="cuda")
|
||||
y = torch.randn(4, 4, dtype=torch.float, device="cuda")
|
||||
|
||||
ge = self.checkTrace(f, (x, y))
|
||||
self.assertAllFused(ge.graph_for(x, y))
|
||||
x.requires_grad_(True)
|
||||
y.requires_grad_(True)
|
||||
self.assertAllFused(ge.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes",
|
||||
"aten::_size_if_not_equal"))
|
||||
self.assertAllFused(
|
||||
ge.graph_for(x, y),
|
||||
except_for=(
|
||||
"aten::size",
|
||||
"prim::BroadcastSizes",
|
||||
"aten::_size_if_not_equal",
|
||||
),
|
||||
)
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
def test_addcmul_cuda(self):
|
||||
t = torch.randn(1, 4, dtype=torch.float, device='cuda')
|
||||
t1 = torch.randn(4, 1, dtype=torch.float, device='cuda')
|
||||
t2 = torch.randn(1, 4, dtype=torch.float, device='cuda')
|
||||
t = torch.randn(1, 4, dtype=torch.float, device="cuda")
|
||||
t1 = torch.randn(4, 1, dtype=torch.float, device="cuda")
|
||||
t2 = torch.randn(1, 4, dtype=torch.float, device="cuda")
|
||||
|
||||
def foo(t, t1, t2):
|
||||
return t.addcmul(t + 1, t2, value=0.1)
|
||||
@ -438,9 +500,9 @@ class TestFuser(JitTestCase):
|
||||
# lifetimes in Python.
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
def test_lerp(self):
|
||||
start = torch.randn(4, 1, dtype=torch.float, device='cuda')
|
||||
end = torch.randn(1, 4, dtype=torch.float, device='cuda')
|
||||
weight = torch.tensor(0.5, dtype=torch.float, device='cuda')
|
||||
start = torch.randn(4, 1, dtype=torch.float, device="cuda")
|
||||
end = torch.randn(1, 4, dtype=torch.float, device="cuda")
|
||||
weight = torch.tensor(0.5, dtype=torch.float, device="cuda")
|
||||
|
||||
# scalar weight overload
|
||||
def foo_weight_scalar(start, end):
|
||||
@ -460,8 +522,8 @@ class TestFuser(JitTestCase):
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
def test_concat_cuda(self):
|
||||
hx = torch.randn(3, 20, dtype=torch.float, device='cuda')
|
||||
cx = torch.randn(3, 20, dtype=torch.float, device='cuda')
|
||||
hx = torch.randn(3, 20, dtype=torch.float, device="cuda")
|
||||
cx = torch.randn(3, 20, dtype=torch.float, device="cuda")
|
||||
|
||||
def foo(hx, cx):
|
||||
return torch.cat((hx + cx, hx * cx))
|
||||
@ -481,22 +543,22 @@ class TestFuser(JitTestCase):
|
||||
w = torch.cat([x1, y1])
|
||||
return w + z
|
||||
|
||||
x = torch.randn(2, 2, dtype=torch.float, device='cuda')
|
||||
y = torch.randn(2, 2, dtype=torch.float, device='cuda')
|
||||
z = torch.randn(4, 2, dtype=torch.float, device='cuda')
|
||||
x = torch.randn(2, 2, dtype=torch.float, device="cuda")
|
||||
y = torch.randn(2, 2, dtype=torch.float, device="cuda")
|
||||
z = torch.randn(4, 2, dtype=torch.float, device="cuda")
|
||||
ge = self.checkTrace(fn, (x, y, z))
|
||||
graph = ge.graph_for(x, y, z)
|
||||
self.assertAllFused(graph, except_for={'aten::add'})
|
||||
self.assertAllFused(graph, except_for={"aten::add"})
|
||||
FileCheck().check("FusedConcat").check_next("return").run(str(graph))
|
||||
|
||||
@staticmethod
|
||||
def fn_test_exp(x, y):
|
||||
return (x + .5 * y).exp()
|
||||
return (x + 0.5 * y).exp()
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
def test_exp_cuda(self):
|
||||
x = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
||||
y = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
||||
x = torch.randn(4, 4, dtype=torch.float, device="cuda")
|
||||
y = torch.randn(4, 4, dtype=torch.float, device="cuda")
|
||||
|
||||
ge = self.checkTrace(self.fn_test_exp, (x, y))
|
||||
self.assertAllFused(ge.graph_for(x, y))
|
||||
@ -519,8 +581,8 @@ class TestFuser(JitTestCase):
|
||||
model = ResLike(nm).cuda()
|
||||
model_noopt = ResLike(nm).cuda()
|
||||
model_noopt.load_state_dict(model.state_dict())
|
||||
x = torch.randn(2, 16, 8, 8, device='cuda')
|
||||
y = torch.randn(2, 16, 8, 8, device='cuda')
|
||||
x = torch.randn(2, 16, 8, 8, device="cuda")
|
||||
y = torch.randn(2, 16, 8, 8, device="cuda")
|
||||
|
||||
# FIXME: We need differentiation for CNNs for this optimization to trigger
|
||||
with torch.no_grad():
|
||||
@ -541,28 +603,35 @@ class TestFuser(JitTestCase):
|
||||
self.assertNotIn(node_not_in_graph, rep)
|
||||
self.assertIn(node_not_in_graph, rep_noopt)
|
||||
|
||||
fusion_groups = [node for node in graph.nodes() if node.kind() == 'prim::FusionGroup']
|
||||
fusion_groups = [
|
||||
node for node in graph.nodes() if node.kind() == "prim::FusionGroup"
|
||||
]
|
||||
self.assertEqual(len(fusion_groups), 1)
|
||||
fused_graph = str(fusion_groups[0].g('Subgraph'))
|
||||
fused_graph = str(fusion_groups[0].g("Subgraph"))
|
||||
for node_in_fusegraph in in_fusegraph:
|
||||
self.assertIn(node_in_fusegraph, fused_graph)
|
||||
|
||||
# test for batchnorm decompose
|
||||
bm = nn.BatchNorm2d(16)
|
||||
test_norm_decompose(bm, ['aten::batch_norm_update_stats'],
|
||||
['aten::batch_norm('], ['aten::sqrt'])
|
||||
test_norm_decompose(
|
||||
bm, ["aten::batch_norm_update_stats"], ["aten::batch_norm("], ["aten::sqrt"]
|
||||
)
|
||||
|
||||
# test for layernorm decompose
|
||||
lm = nn.LayerNorm(8)
|
||||
test_norm_decompose(lm, ['aten::batch_norm_stats'],
|
||||
['aten::layer_norm('], ['aten::sub', 'aten::mul', 'aten::add'])
|
||||
test_norm_decompose(
|
||||
lm,
|
||||
["aten::batch_norm_stats"],
|
||||
["aten::layer_norm("],
|
||||
["aten::sub", "aten::mul", "aten::add"],
|
||||
)
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
def test_threshold(self):
|
||||
def f(x):
|
||||
return torch.threshold(x, 0, -10) + x + x + x
|
||||
|
||||
x = torch.tensor([-1, -0.5, 0, 1, 2, 3], device='cuda')
|
||||
x = torch.tensor([-1, -0.5, 0, 1, 2, 3], device="cuda")
|
||||
scripted = self.checkScript(f, (x,))
|
||||
self.assertAllFused(scripted.graph_for(x))
|
||||
|
||||
@ -571,7 +640,7 @@ class TestFuser(JitTestCase):
|
||||
def fn_test_scalar_arg(x: torch.Tensor, p: float) -> torch.Tensor:
|
||||
return p * (x * x + x)
|
||||
|
||||
x = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
||||
x = torch.randn(4, 4, dtype=torch.float, device="cuda")
|
||||
p = 3
|
||||
scripted = self.checkScript(fn_test_scalar_arg, (x, p))
|
||||
self.assertAllFused(scripted.graph_for(x, p))
|
||||
@ -585,8 +654,14 @@ class TestFuser(JitTestCase):
|
||||
|
||||
scripted = torch.jit.script(fn_test_scalar_arg_requires_grad)
|
||||
out = scripted(x, p)
|
||||
self.assertAllFused(scripted.graph_for(x, p), except_for=("aten::size", "prim::BroadcastSizes",
|
||||
"aten::_size_if_not_equal"))
|
||||
self.assertAllFused(
|
||||
scripted.graph_for(x, p),
|
||||
except_for=(
|
||||
"aten::size",
|
||||
"prim::BroadcastSizes",
|
||||
"aten::_size_if_not_equal",
|
||||
),
|
||||
)
|
||||
|
||||
@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
|
||||
@unittest.skip("deduplicating introduces aliasing in backward graph's outputs")
|
||||
@ -600,8 +675,14 @@ class TestFuser(JitTestCase):
|
||||
b = torch.randn(5, 5, requires_grad=True)
|
||||
a = torch.randn(5, 5, requires_grad=True)
|
||||
s = self.checkScript(f, (a, b))
|
||||
self.assertAllFused(s.graph_for(a, b), except_for={
|
||||
'aten::size', 'aten::_size_if_not_equal', 'prim::BroadcastSizes'})
|
||||
self.assertAllFused(
|
||||
s.graph_for(a, b),
|
||||
except_for={
|
||||
"aten::size",
|
||||
"aten::_size_if_not_equal",
|
||||
"prim::BroadcastSizes",
|
||||
},
|
||||
)
|
||||
|
||||
c = s(a, b)
|
||||
results = warmup_backward(c.sum(), [a, b])
|
||||
@ -613,7 +694,9 @@ class TestFuser(JitTestCase):
|
||||
|
||||
@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
|
||||
@enable_cpu_fuser
|
||||
@unittest.skip("temporarily disabled because fusion was restricted in fixing #22833")
|
||||
@unittest.skip(
|
||||
"temporarily disabled because fusion was restricted in fixing #22833"
|
||||
)
|
||||
def test_fuser_iou(self):
|
||||
# This checks if most of Intersection over Union is fused.
|
||||
# In particular, the backward contains many _grad_sum_to_size.
|
||||
@ -623,8 +706,8 @@ class TestFuser(JitTestCase):
|
||||
rbx = torch.min(b1x2, b2x2)
|
||||
rby = torch.min(b1y2, b2y2)
|
||||
|
||||
w = (rbx - ltx).clamp(min=0, max=float('inf')) # [N,M]
|
||||
h = (rby - lty).clamp(min=0, max=float('inf')) # [N,M]
|
||||
w = (rbx - ltx).clamp(min=0, max=float("inf")) # [N,M]
|
||||
h = (rby - lty).clamp(min=0, max=float("inf")) # [N,M]
|
||||
inter = w * h # [N,M]
|
||||
|
||||
area1 = (b1x2 - b1x1) * (b1y2 - b1y2) # [N,1]
|
||||
@ -645,14 +728,27 @@ class TestFuser(JitTestCase):
|
||||
b2y2 = box2[:, 3].unsqueeze(0)
|
||||
|
||||
s = self.checkScript(iou, (b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2))
|
||||
self.assertAllFused(s.graph_for(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2),
|
||||
except_for={'aten::size', 'prim::BroadcastSizes', 'aten::_size_if_not_equal'})
|
||||
self.assertAllFused(
|
||||
s.graph_for(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2),
|
||||
except_for={
|
||||
"aten::size",
|
||||
"prim::BroadcastSizes",
|
||||
"aten::_size_if_not_equal",
|
||||
},
|
||||
)
|
||||
|
||||
with enable_profiling_mode_for_profiling_tests(True):
|
||||
c = s(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2)
|
||||
warmup_backward(c.sum(), [b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2])
|
||||
graph = backward_graph(s)
|
||||
self.assertAllFused(graph, except_for={'aten::size', 'prim::BroadcastSizes', 'aten::_size_if_not_equal'})
|
||||
self.assertAllFused(
|
||||
graph,
|
||||
except_for={
|
||||
"aten::size",
|
||||
"prim::BroadcastSizes",
|
||||
"aten::_size_if_not_equal",
|
||||
},
|
||||
)
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
|
||||
@ -689,8 +785,8 @@ class TestFuser(JitTestCase):
|
||||
|
||||
inputs = [
|
||||
torch.randn(4, 4, dtype=torch.float),
|
||||
torch.randn(4, 4, dtype=torch.float, device='cuda:0'),
|
||||
torch.randn(4, 4, dtype=torch.float, device='cuda:1'),
|
||||
torch.randn(4, 4, dtype=torch.float, device="cuda:0"),
|
||||
torch.randn(4, 4, dtype=torch.float, device="cuda:1"),
|
||||
]
|
||||
|
||||
prev_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
|
||||
@ -699,14 +795,15 @@ class TestFuser(JitTestCase):
|
||||
# should reuse the same KernelSpec in the KernelSpec cache.
|
||||
ge = self.checkScript(fn, inputs)
|
||||
self.assertGraphContainsExactly(
|
||||
ge.graph_for(*inputs), 'prim::FusionGroup', 3, True)
|
||||
ge.graph_for(*inputs), "prim::FusionGroup", 3, True
|
||||
)
|
||||
new_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
|
||||
# XXX: This assumes that the same kernel isn't already used by another test
|
||||
self.assertEqual(new_cache_size - prev_cache_size, 1)
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
|
||||
def test_nonzero_device_cuda(self):
|
||||
device = 'cuda:' + str(1)
|
||||
device = "cuda:" + str(1)
|
||||
x = torch.tensor([0.4], dtype=torch.float, device=device)
|
||||
y = torch.tensor([0.7], dtype=torch.float, device=device)
|
||||
|
||||
@ -718,30 +815,33 @@ class TestFuser(JitTestCase):
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
def test_lstm_cuda(self):
|
||||
inputs = get_lstm_inputs('cuda', training=True)
|
||||
inputs = get_lstm_inputs("cuda", training=True)
|
||||
module = self.checkScript(LSTMCellS, inputs)
|
||||
return
|
||||
forward_graph = module.graph_for(*inputs)
|
||||
self.assertGraphContainsExactly(
|
||||
forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
|
||||
forward_graph, "prim::FusionGroup", 1, consider_subgraphs=True
|
||||
)
|
||||
self.assertTrue(len(strip_profiling_nodes(forward_graph.nodes())) == 2)
|
||||
# Everything is differentiable but TupleConstruct return
|
||||
FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
|
||||
.check_next("return").run(str(forward_graph))
|
||||
FileCheck().check("DifferentiableGraph").check_next(
|
||||
"TupleConstruct"
|
||||
).check_next("return").run(str(forward_graph))
|
||||
|
||||
with enable_profiling_mode_for_profiling_tests(True):
|
||||
hy, cy = module(*inputs)
|
||||
warmup_backward((hy + cy).sum())
|
||||
backward = backward_graph(module)
|
||||
self.assertAllFused(backward, except_for=("aten::t", "aten::mm",
|
||||
"aten::_grad_sum_to_size"))
|
||||
self.assertAllFused(
|
||||
backward, except_for=("aten::t", "aten::mm", "aten::_grad_sum_to_size")
|
||||
)
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
# By default, on Ampere or later GPUs, LSTM computes float tensors at TF32 precision.
|
||||
# We want float tensors to be computed at full precision in order to use the default precision
|
||||
@with_tf32_off
|
||||
def test_lstm_concat_cuda(self):
|
||||
inputs = get_lstm_inputs('cuda')
|
||||
inputs = get_lstm_inputs("cuda")
|
||||
ge = self.checkTrace(LSTMCellC, inputs)
|
||||
graph = ge.graph_for(*inputs)
|
||||
FileCheck().check("FusedConcat").check_next("return").run(str(graph))
|
||||
@ -750,23 +850,25 @@ class TestFuser(JitTestCase):
|
||||
def test_lstm_gates_permutations_cuda(self):
|
||||
# lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh.
|
||||
# Test that any permutation of this will still result in one FusionGroup.
|
||||
choices = ['x.mm(w_ih.t())', 'hx.mm(w_hh.t())', 'b_ih', 'b_hh']
|
||||
template = dedent('''
|
||||
choices = ["x.mm(w_ih.t())", "hx.mm(w_hh.t())", "b_ih", "b_hh"]
|
||||
template = dedent(
|
||||
"""
|
||||
def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
|
||||
gates = {} + {} + {} + {}
|
||||
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
|
||||
return ingate * forgetgate * cellgate * outgate
|
||||
''')
|
||||
"""
|
||||
)
|
||||
for permutation in permutations(choices, len(choices)):
|
||||
code = template.format(*permutation)
|
||||
scope = {}
|
||||
exec(code, globals(), scope)
|
||||
cu = torch.jit.CompilationUnit(code)
|
||||
|
||||
inputs = get_lstm_inputs('cuda', training=False)
|
||||
self.assertEqual(cu.cell(*inputs), scope['cell'](*inputs))
|
||||
inputs = get_lstm_inputs("cuda", training=False)
|
||||
self.assertEqual(cu.cell(*inputs), scope["cell"](*inputs))
|
||||
forward_graph = cu.cell.graph_for(*inputs)
|
||||
self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
|
||||
self.assertGraphContainsExactly(forward_graph, "prim::FusionGroup", 1)
|
||||
|
||||
# TODO: Fuser doesn't work at all when inputs require grad. Fix that
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
@ -774,59 +876,71 @@ class TestFuser(JitTestCase):
|
||||
# We want float tensors to be computed at full precision in order to use the default precision
|
||||
@with_tf32_off
|
||||
def test_lstm_traced_cuda(self):
|
||||
inputs = get_lstm_inputs('cuda')
|
||||
inputs = get_lstm_inputs("cuda")
|
||||
ge = self.checkTrace(LSTMCellF, inputs)
|
||||
graph = ge.graph_for(*inputs)
|
||||
# .check_not("aten::add") don't get pulled into FusionGroup because of BailOuts
|
||||
FileCheck().check_not("Chunk").check_not("aten::sigmoid") \
|
||||
.check_not("aten::tanh").check("FusionGroup").check_next("TupleConstruct") \
|
||||
.check_next("return").check_not("FusionGroup_2").run(str(graph))
|
||||
FileCheck().check_not("Chunk").check_not("aten::sigmoid").check_not(
|
||||
"aten::tanh"
|
||||
).check("FusionGroup").check_next("TupleConstruct").check_next(
|
||||
"return"
|
||||
).check_not(
|
||||
"FusionGroup_2"
|
||||
).run(
|
||||
str(graph)
|
||||
)
|
||||
|
||||
@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
|
||||
@unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/8746")
|
||||
@enable_cpu_fuser
|
||||
def test_lstm_traced_cpu(self):
|
||||
inputs = get_lstm_inputs('cpu')
|
||||
inputs = get_lstm_inputs("cpu")
|
||||
try:
|
||||
ge = self.checkTrace(LSTMCellF, inputs)
|
||||
graph = ge.graph_for(*inputs)
|
||||
FileCheck.check("FusionGroup").run(str(graph))
|
||||
except RuntimeError as e:
|
||||
if 'Failed to compile' in e.args[0]:
|
||||
warnings.warn('CPU fuser test has failed! This is not a hard failure, '
|
||||
'because the kernels sometimes trigger bugs in compilers '
|
||||
'(most notably GCC 7.2).')
|
||||
raise unittest.SkipTest('Failed to compile') from e
|
||||
if "Failed to compile" in e.args[0]:
|
||||
warnings.warn(
|
||||
"CPU fuser test has failed! This is not a hard failure, "
|
||||
"because the kernels sometimes trigger bugs in compilers "
|
||||
"(most notably GCC 7.2)."
|
||||
)
|
||||
raise unittest.SkipTest("Failed to compile") from e
|
||||
else:
|
||||
raise
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
def test_milstm_cuda(self):
|
||||
inputs = get_milstm_inputs('cuda', training=True)
|
||||
inputs = get_milstm_inputs("cuda", training=True)
|
||||
module = self.checkScript(MiLSTMCell, inputs)
|
||||
forward_graph = module.graph_for(*inputs)
|
||||
self.assertGraphContainsExactly(
|
||||
forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
|
||||
FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
|
||||
.check_next("return").check("FusionGroup").run(str(forward_graph))
|
||||
forward_graph, "prim::FusionGroup", 1, consider_subgraphs=True
|
||||
)
|
||||
FileCheck().check("DifferentiableGraph").check_next(
|
||||
"TupleConstruct"
|
||||
).check_next("return").check("FusionGroup").run(str(forward_graph))
|
||||
hy, cy = module(*inputs)
|
||||
warmup_backward((hy + cy).sum())
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
@unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "borked on the legacy executor")
|
||||
@unittest.skipIf(
|
||||
GRAPH_EXECUTOR == ProfilingMode.LEGACY, "borked on the legacy executor"
|
||||
)
|
||||
def test_rand_cuda(self):
|
||||
class M(torch.jit.ScriptModule):
|
||||
__constants__ = ['d']
|
||||
__constants__ = ["d"]
|
||||
|
||||
def __init__(self):
|
||||
super(M, self).__init__()
|
||||
self.d = torch.device('cuda')
|
||||
self.d = torch.device("cuda")
|
||||
|
||||
@torch.jit.script_method
|
||||
def create(self, x):
|
||||
return x * x + x + torch.rand_like(x)
|
||||
|
||||
x = torch.zeros([3, 4, 5], dtype=torch.float, device='cuda')
|
||||
x = torch.zeros([3, 4, 5], dtype=torch.float, device="cuda")
|
||||
m = M()
|
||||
out1 = m.create(x)
|
||||
out2 = m.create(x)
|
||||
@ -839,12 +953,12 @@ class TestFuser(JitTestCase):
|
||||
|
||||
@staticmethod
|
||||
def fn_test_relu(x, y):
|
||||
return F.relu(x + .5 * y)
|
||||
return F.relu(x + 0.5 * y)
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
def test_relu_cuda(self):
|
||||
x = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
||||
y = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
||||
x = torch.randn(4, 4, dtype=torch.float, device="cuda")
|
||||
y = torch.randn(4, 4, dtype=torch.float, device="cuda")
|
||||
|
||||
ge = self.checkTrace(self.fn_test_relu, (x, y))
|
||||
self.assertAllFused(ge.graph_for(x, y))
|
||||
@ -854,33 +968,47 @@ class TestFuser(JitTestCase):
|
||||
def fn_test_erf(x):
|
||||
return F.relu(torch.erf(x) - torch.erfc(x))
|
||||
|
||||
x = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
||||
x = torch.randn(4, 4, dtype=torch.float, device="cuda")
|
||||
ge = self.checkTrace(fn_test_erf, (x,))
|
||||
self.assertAllFused(ge.graph_for(x))
|
||||
x.requires_grad_(True)
|
||||
ge = self.checkTrace(fn_test_erf, (x,))
|
||||
self.assertAllFused(ge.graph_for(x), except_for=("aten::size", "prim::BroadcastSizes",
|
||||
"aten::_size_if_not_equal"))
|
||||
self.assertAllFused(
|
||||
ge.graph_for(x),
|
||||
except_for=(
|
||||
"aten::size",
|
||||
"prim::BroadcastSizes",
|
||||
"aten::_size_if_not_equal",
|
||||
),
|
||||
)
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
@unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "borked on the legacy executor")
|
||||
@unittest.skipIf(
|
||||
GRAPH_EXECUTOR == ProfilingMode.LEGACY, "borked on the legacy executor"
|
||||
)
|
||||
def test_rand_broadcast_cuda(self):
|
||||
def fn_test_rand(x, y):
|
||||
r = torch.rand_like(y)
|
||||
return r * x + x
|
||||
|
||||
x = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
||||
y = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
||||
x = torch.randn(4, 4, dtype=torch.float, device="cuda")
|
||||
y = torch.randn(4, 4, dtype=torch.float, device="cuda")
|
||||
script_f = torch.jit.script(fn_test_rand)
|
||||
out = script_f(x, y)
|
||||
self.assertAllFused(script_f.graph_for(x, y))
|
||||
x.requires_grad_(True)
|
||||
out = script_f(x, y)
|
||||
self.assertAllFused(script_f.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes",
|
||||
"aten::_size_if_not_equal"))
|
||||
self.assertAllFused(
|
||||
script_f.graph_for(x, y),
|
||||
except_for=(
|
||||
"aten::size",
|
||||
"prim::BroadcastSizes",
|
||||
"aten::_size_if_not_equal",
|
||||
),
|
||||
)
|
||||
# test that broadcasting random produces correct results
|
||||
x = torch.ones(4, 4, dtype=torch.float, device='cuda')
|
||||
y = torch.ones(4, dtype=torch.float, device='cuda')
|
||||
x = torch.ones(4, 4, dtype=torch.float, device="cuda")
|
||||
y = torch.ones(4, dtype=torch.float, device="cuda")
|
||||
out = script_f(x, y)
|
||||
self.assertEqual(out[0], out[1])
|
||||
|
||||
@ -890,8 +1018,8 @@ class TestFuser(JitTestCase):
|
||||
def fn(x, y):
|
||||
return 2 * x + y
|
||||
|
||||
x = torch.tensor(0.1, dtype=torch.float, device='cpu')
|
||||
y = torch.tensor(1, dtype=torch.float, device='cpu')
|
||||
x = torch.tensor(0.1, dtype=torch.float, device="cpu")
|
||||
y = torch.tensor(1, dtype=torch.float, device="cpu")
|
||||
ge = self.checkScript(fn, (x, y))
|
||||
self.assertAllFused(ge.graph_for(x, y))
|
||||
|
||||
@ -899,8 +1027,9 @@ class TestFuser(JitTestCase):
|
||||
def test_small_constant_cuda(self):
|
||||
def fn_test_small_constant(x, y):
|
||||
return (1e-8 * x + 5e-9 * y) * 1e8
|
||||
x = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
||||
y = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
||||
|
||||
x = torch.randn(4, 4, dtype=torch.float, device="cuda")
|
||||
y = torch.randn(4, 4, dtype=torch.float, device="cuda")
|
||||
|
||||
ge = self.checkTrace(fn_test_small_constant, (x, y))
|
||||
self.assertAllFused(ge.graph_for(x, y))
|
||||
@ -908,7 +1037,7 @@ class TestFuser(JitTestCase):
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
def test_tensor_scalar_ops_cuda(self):
|
||||
def should_fuse(x):
|
||||
z = 3.
|
||||
z = 3.0
|
||||
y = x + z
|
||||
return x * y
|
||||
|
||||
@ -918,17 +1047,18 @@ class TestFuser(JitTestCase):
|
||||
y = x + int(z)
|
||||
return x * y
|
||||
|
||||
inputs = [torch.randn(2, 2, dtype=torch.float, device='cuda')]
|
||||
inputs = [torch.randn(2, 2, dtype=torch.float, device="cuda")]
|
||||
ge = self.checkScript(should_fuse, inputs)
|
||||
self.assertAllFused(ge.graph_for(*inputs))
|
||||
|
||||
inputs = [
|
||||
torch.randn(2, 2, dtype=torch.float, device='cuda'),
|
||||
torch.tensor(3., dtype=torch.float, device='cuda'),
|
||||
torch.randn(2, 2, dtype=torch.float, device="cuda"),
|
||||
torch.tensor(3.0, dtype=torch.float, device="cuda"),
|
||||
]
|
||||
ge = self.checkScript(should_not_fuse, inputs)
|
||||
self.assertGraphContainsExactly(
|
||||
ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True)
|
||||
ge.graph_for(*inputs), "prim::FusionGroup", 0, consider_subgraphs=True
|
||||
)
|
||||
|
||||
@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
|
||||
@enable_cpu_fuser
|
||||
@ -942,22 +1072,33 @@ class TestFuser(JitTestCase):
|
||||
y = torch.randn(4, 4, dtype=torch.double)
|
||||
|
||||
script_f = self.checkScript(f, (x, y))
|
||||
self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'})
|
||||
self.assertAllFused(
|
||||
script_f.graph_for(x, y), except_for={"prim::TupleConstruct"}
|
||||
)
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
|
||||
@unittest.skipIf(
|
||||
GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on"
|
||||
)
|
||||
def test_grad_sum_to_size_elimination(self):
|
||||
|
||||
def my_broadcasted_cell(a, b, c):
|
||||
return (a + b) + c
|
||||
|
||||
s1 = torch.randn(5, 1, requires_grad=True, device='cuda')
|
||||
s2 = torch.randn(5, 5, requires_grad=True, device='cuda')
|
||||
s1 = torch.randn(5, 1, requires_grad=True, device="cuda")
|
||||
s2 = torch.randn(5, 5, requires_grad=True, device="cuda")
|
||||
|
||||
module = self.checkScript(my_broadcasted_cell, (s1, s1, s1), profiling=ProfilingMode.PROFILING)
|
||||
module = self.checkScript(
|
||||
my_broadcasted_cell, (s1, s1, s1), profiling=ProfilingMode.PROFILING
|
||||
)
|
||||
forward_graph = module.graph_for(s1, s1, s1)
|
||||
self.assertAllFused(forward_graph, except_for=("aten::size", "prim::BroadcastSizes",
|
||||
"aten::_size_if_not_equal"))
|
||||
self.assertAllFused(
|
||||
forward_graph,
|
||||
except_for=(
|
||||
"aten::size",
|
||||
"prim::BroadcastSizes",
|
||||
"aten::_size_if_not_equal",
|
||||
),
|
||||
)
|
||||
|
||||
old_plans = set()
|
||||
for i in range(3):
|
||||
@ -966,7 +1107,9 @@ class TestFuser(JitTestCase):
|
||||
args = s2 if i < 1 else s1, s2 if i < 2 else s1, s2
|
||||
args = [a.detach_().requires_grad_() for a in args]
|
||||
# recompile, so we don't trigger bailouts
|
||||
module = self.checkScript(my_broadcasted_cell, args, profiling=ProfilingMode.PROFILING)
|
||||
module = self.checkScript(
|
||||
my_broadcasted_cell, args, profiling=ProfilingMode.PROFILING
|
||||
)
|
||||
res = module(s2 if i < 1 else s1, s2 if i < 2 else s1, s2)
|
||||
warmup_backward(res.sum(), args)
|
||||
grads = torch.autograd.grad(res.sum(), args)
|
||||
@ -981,8 +1124,17 @@ class TestFuser(JitTestCase):
|
||||
backward = g
|
||||
old_plans.add(str(backward))
|
||||
num_grads = 1 if i > 0 else 0
|
||||
self.assertEqual(len([n for n in backward.nodes() if n.kind() == 'aten::_grad_sum_to_size']), num_grads)
|
||||
self.assertEqual(
|
||||
len(
|
||||
[
|
||||
n
|
||||
for n in backward.nodes()
|
||||
if n.kind() == "aten::_grad_sum_to_size"
|
||||
]
|
||||
),
|
||||
num_grads,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user