Files
pytorch/test/test_jit_fuser_te.py
2025-01-22 04:48:28 +00:00

3048 lines
107 KiB
Python

# Owner(s): ["NNC"]
# ruff: noqa: F841
import contextlib
import math
import operator
import os
import unittest
import warnings
import torch
import torch.nn.functional as F
from torch.testing import FileCheck
# These need to be set before `common_utils`
# infers `GRAPH_EXECUTOR`.
# This file **requires** these settings;
# setting them after `GRAPH_EXECUTOR` is
# inferred erroneously runs or skips
# some tests.
torch._C._jit_set_profiling_executor(True)
torch._C._get_graph_executor_optimize(True)
from itertools import combinations, permutations, product
from textwrap import dedent
from jit.test_fuser_common import TestFuserCommon # noqa: F401
from test_jit import (
backward_graph,
get_lstm_inputs,
get_milstm_inputs,
LSTMCellC,
LSTMCellF,
LSTMCellS,
MiLSTMCell,
)
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
onlyCPU,
OpDTypes,
ops,
)
from torch.testing._internal.common_jit import JitCommonTestCase
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import (
enable_profiling_mode_for_profiling_tests,
GRAPH_EXECUTOR,
IS_FBCODE,
ProfilingMode,
run_tests,
skipIfTorchDynamo,
slowTest,
TEST_WITH_ASAN,
TEST_WITH_ROCM,
)
from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn
from torch.testing._internal.jit_utils import (
clone_inputs,
get_traced_sample_variant_pairs,
JitTestCase,
NoTracerWarnContextManager,
RUN_CUDA,
RUN_CUDA_HALF,
RUN_CUDA_MULTI_GPU,
set_fusion_group_inlining,
TensorExprTestOptions,
warmup_backward,
)
# Node kind that `findFusionGroups` and the graph-count assertions in this
# file look for when locating fused subgraphs.
FUSION_GROUP = "prim::TensorExprGroup"
# Value of `torch._C._llvm_enabled()`, captured once at import time.
LLVM_ENABLED = torch._C._llvm_enabled()
# Extra node kinds that `test_clamp` tolerates (through `except_for`) when it
# checks that a backward graph is fully fused.
autograd_check_set = {
    "aten::__is__",
    "prim::AutogradAllNonZero",
    "prim::AutogradAllZero",
    "prim::ListConstruct",
}
def strip_profiling_nodes(nodes):
    """Return the entries of ``nodes`` that are not bailout nodes.

    A node is dropped when its ``kind()`` is ``prim::BailoutTemplate`` or
    ``prim::BailOut``; every other node is kept, in its original order.
    """
    kept = []
    for candidate in nodes:
        if candidate.kind() in ("prim::BailoutTemplate", "prim::BailOut"):
            continue
        kept.append(candidate)
    return kept
def warmup_forward(f, *args, profiling_count=2):
    """Call ``f(*args)`` ``profiling_count`` times and return the last result.

    Args:
        f: Callable to invoke.
        *args: Positional arguments forwarded unchanged to every call.
        profiling_count: Number of calls to make (defaults to 2).

    Returns:
        The value returned by the final call, or ``None`` when
        ``profiling_count`` is not positive and ``f`` is never called.
    """
    # Pre-bind the result so that a non-positive count returns None instead of
    # raising UnboundLocalError at the ``return`` below.
    results = None
    for _ in range(profiling_count):
        results = f(*args)
    return results
@contextlib.contextmanager
def texpr_reductions_enabled():
    """Run the ``with`` body after ``_jit_set_texpr_reductions_enabled(True)``.

    The value returned by that call is handed back to the same setter when the
    body finishes, whether it returns normally or raises.
    """
    set_enabled = torch._C._jit_set_texpr_reductions_enabled
    previous = set_enabled(True)
    try:
        yield
    finally:
        set_enabled(previous)
@contextlib.contextmanager
def texpr_enable_strategy(strategy):
    """Run the ``with`` body with ``strategy`` installed as the fusion strategy.

    ``strategy`` is passed to ``torch._C._jit_set_fusion_strategy``; the value
    that call returns is passed back to the same setter on exit, whether the
    body returns normally or raises.
    """
    set_strategy = torch._C._jit_set_fusion_strategy
    previous = set_strategy(strategy)
    try:
        yield
    finally:
        set_strategy(previous)
@contextlib.contextmanager
def inline_fusion_groups():
    """Run the ``with`` body with fusion-group inlining switched on.

    The current setting is read first with
    ``_debug_get_fusion_group_inlining`` and written back on exit, whether the
    body returns normally or raises.
    """
    saved = torch._C._debug_get_fusion_group_inlining()
    set_inlining = torch._C._debug_set_fusion_group_inlining
    set_inlining(True)
    try:
        yield
    finally:
        set_inlining(saved)
class TestTEFuser(JitTestCase):
def setUp(self):
    """Set up TE test options, the fusion strategy and device/dtype lists."""
    super().setUp()
    self.tensorexpr_options = TensorExprTestOptions()

    # note: `self.dynamic_shapes` is instantiated in a specialization of this
    # class defined below
    strategy_kind = "DYNAMIC" if self.dynamic_shapes else "STATIC"
    # Keep whatever the setter returns so tearDown can hand it back.
    self.old_fusion_strategy = torch._C._jit_set_fusion_strategy(
        [(strategy_kind, 20)]
    )

    available = ["cpu"]
    if torch.cuda.is_available():
        available.append("cuda")
    self.devices = available

    self.int_dtypes = [
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.bool,
    ]
    self.fp_dtypes = [
        torch.float16,
        torch.float32,
        torch.float64,
        torch.bfloat16,
    ]
    self.dtypes = self.int_dtypes + self.fp_dtypes
def tearDown(self):
    """Undo the global state changed in ``setUp``, then run the base teardown."""
    # Call restore() on the options object created in setUp.
    self.tensorexpr_options.restore()
    # Hand back the fusion strategy value that setUp saved.
    torch._C._jit_set_fusion_strategy(self.old_fusion_strategy)
    super().tearDown()
def assertAllFused(self, graph, except_for=None):
    """Assert that the top-level nodes of ``graph`` are one guard plus fusion.

    Walks ``graph.block().nodes()`` and requires that:

    * ``prim::Constant`` nodes are ignored;
    * exactly one guard node is present (a ``prim::TypeCheck``,
      ``prim::RequiresGradCheck`` or ``prim::TensorExprDynamicGuard`` node, or
      an ``aten::all`` over a ``prim::ListConstruct`` that has an autograd
      zero/non-zero check among its inputs);
    * node kinds listed in ``except_for`` are tolerated;
    * every ``prim::If`` is immediately preceded by a guard node;
    * any other node fails the assertion.

    Args:
        graph: Object exposing ``block().nodes()``.
        except_for: Optional container of node kinds to tolerate.
    """
    tolerated = set() if except_for is None else except_for
    # TODO - upstream
    guard_kinds = (
        "prim::TypeCheck",
        "prim::RequiresGradCheck",
        "prim::TensorExprDynamicGuard",
    )
    autograd_kinds = ("prim::AutogradAllNonZero", "prim::AutogradAllZero")

    def autodiff_guard(node):
        if node.kind() != "aten::all":
            return False
        operands = list(node.inputs())
        if len(operands) != 1:
            return False
        producer = operands[0].node()
        if producer.kind() != "prim::ListConstruct":
            return False
        return any(
            element.node().kind() in autograd_kinds for element in producer.inputs()
        )

    def is_guard(node):
        return node.kind() in guard_kinds or autodiff_guard(node)

    guard_found = False
    for node in graph.block().nodes():
        kind = node.kind()
        if kind == "prim::Constant":
            continue
        if is_guard(node):
            # A second guard means the graph was split into several groups.
            self.assertFalse(guard_found)
            guard_found = True
        elif kind in tolerated:
            pass
        elif kind == "prim::If":
            self.assertTrue(is_guard(node.prev()))
        else:
            self.assertTrue(False, "Found unexpected node:" + kind)

    self.assertTrue(guard_found)
def assertLastGraphAllFused(self):
    """Run ``assertAllFused`` on ``torch.jit.last_executed_optimized_graph()``."""
    last_graph = torch.jit.last_executed_optimized_graph()
    self.assertAllFused(last_graph)
def findFusionGroups(self, graph):
    """Collect the ``Subgraph`` attribute of every fusion-group node.

    Nodes whose kind equals ``FUSION_GROUP`` contribute their subgraph and are
    not descended into; for every other node the search recurses into each of
    its blocks. Results are returned in traversal order.
    """
    found = []
    for node in graph.nodes():
        if node.kind() != FUSION_GROUP:
            for inner_block in node.blocks():
                found.extend(self.findFusionGroups(inner_block))
        else:
            found.append(node.g("Subgraph"))
    return found
def test_typecheck(self):
    """Check that a fused kernel still matches eager after the shape changes."""

    def fused_kernel(a, b):
        return (a + b) * 2.0

    small = torch.ones(1)
    scripted = self.checkScript(fused_kernel, (small, small))
    graph = scripted.graph_for(small, small)
    # double check we fused
    self.assertEqual(len(self.findFusionGroups(graph)), 1)
    # Switch to a bigger tensor (size 2). If the type check failed to
    # trigger a recompilation, the kernel specialized for size 1 would be
    # reused and the wrong result would be computed silently.
    bigger = torch.ones(2)
    self.assertEqual(scripted(bigger, bigger), fused_kernel(bigger, bigger))
def test_sum_simple(self):
    """Script a square-then-full-sum function and assert it is fully fused."""

    def func(x):
        x2 = x * x
        return x2.sum()

    with texpr_reductions_enabled():
        values = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu")
        matrix = values.reshape(5, 3)
        self.checkScript(func, (matrix,))
        self.assertLastGraphAllFused()
def test_nop(self):
    """Empty test body; only setUp and tearDown run around it."""
    pass
def test_sum_dim(self):
    """Fuse a one-dimension sum given as index ``0`` and as index ``-2``."""

    def func(x):
        return x.sum((0,)) * 2

    def func_neg(x):
        return x.sum((-2,)) * 2

    with texpr_reductions_enabled():
        values = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu")
        matrix = values.reshape(5, 3)
        for reducer in (func, func_neg):
            self.checkScript(reducer, (matrix,))
            self.assertLastGraphAllFused()
def test_sum_keepdim_cast(self):
    """Fuse a sum that uses ``keepdim=True`` and an explicit double dtype."""

    def func(x):
        return x.sum((0,), keepdim=True, dtype=torch.double) * 2

    with texpr_reductions_enabled():
        values = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu")
        matrix = values.reshape(5, 3)
        self.checkScript(func, (matrix,))
        self.assertLastGraphAllFused()
def test_abs(self):
    """Script ``x.abs() * 2`` on each device and assert it is fully fused."""
    for device in self.devices:

        def func(x):
            return x.abs() * 2

        sample = torch.randn(5, device=device)
        self.checkScript(func, (sample,))
        self.assertLastGraphAllFused()
def test_unsqueeze_size_calculation(self):
    """Fuse an unsqueeze followed by broadcasting arithmetic on each device."""
    for device in self.devices:

        def foo(b, d):
            x = d.unsqueeze(1)
            y = x * 42.0
            z = b + y
            r = z / 42.0
            return r

        matrix = torch.rand(20, 28, device=device, requires_grad=True)
        vector = torch.rand(20, device=device)
        scripted = self.checkScript(foo, (matrix, vector))
        self.assertAllFused(scripted.graph_for(matrix, vector))
def test_zero_element_tensors(self):
    """Script an ``atan2`` function and run it on two zero-element tensors."""
    for device in self.devices:

        def decode(sin_t, cos_t):
            theta = torch.atan2(sin_t.float(), cos_t.float())
            return theta

        empties = [torch.zeros(0, device=device) for _ in range(2)]
        self.checkScript(decode, empties)
def test_arg_configurations_smoke(self):
    """Compare a traced function on a contiguous and a transposed input."""
    if self.dynamic_shapes:
        self.skipTest("TODO: chunk dynamic shapes")

    # A smoke test to make sure we won't use the same kernel for contiguous
    # and non-contiguous arguments.
    # TODO: add optionally enabled debug counters to the fuser to verify
    #       that we really can tell the difference between configurations
    for device in self.devices:

        def f(x, y):
            z1, z2 = (x + y).chunk(2, dim=1)
            return z1 * z2

        x, y = (
            torch.randn(4, 4, dtype=torch.float, device=device) for _ in range(2)
        )
        traced_f = torch.jit.trace(f, (x, y))
        from_contiguous = traced_f(x.t().contiguous(), y)
        from_strided = traced_f(x.t(), y)
        self.assertEqual(from_contiguous, from_strided)
def test_broadcast(self):
    """Script ``x * scale + shift`` with a 4x4 input and length-4 operands."""
    for device in self.devices:

        def scaleshift(x, scale, shift):
            return x * scale + shift

        x = torch.randn(4, 4, dtype=torch.float, device=device)
        scale = torch.randn(4, dtype=torch.float, device=device)
        shift = torch.randn(4, dtype=torch.float, device=device)
        self.checkScript(scaleshift, [x, scale, shift])
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@unittest.skipIf(not RUN_CUDA_HALF, "no half support")
@unittest.skipIf(
    GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on"
)
def test_cuda_half(self):
    """Compare traced half-precision results and gradients with float eager.

    For each helper function the eager result on float copies of the inputs is
    cast to half and compared with the traced result on the half inputs; the
    same comparison is then made for the gradients of each output.
    """
    x = torch.randn(4, 4, dtype=torch.half, device="cuda")
    y = torch.randn(4, 4, dtype=torch.half, device="cuda")

    funcs = [self.fn_test_comparison_gt_lt, self.fn_test_relu, self.fn_test_exp]

    # Note: Non fused inputs must be float to prevent loss of precision
    inputs = (x.float(), y.float())
    fusion_inputs = (x, y)
    for fn in funcs:
        # Fresh leaf tensors per function so gradients do not mix.
        local_inputs = [t.clone().requires_grad_() for t in inputs]
        local_fusion_inputs = [t.clone().requires_grad_() for t in fusion_inputs]

        # Verifies outputs
        fusion = torch.jit.trace(fn, local_fusion_inputs, check_trace=False)
        outputs = fn(*local_inputs)
        fusion_outputs = fusion(*local_fusion_inputs)
        outputs_half = [t.half() for t in outputs]
        self.assertEqual(outputs_half, fusion_outputs)

        # Verifies gradients
        for output, fusion_output in zip(outputs_half, fusion_outputs):
            grads = torch.autograd.grad(
                output.float().sum(),
                local_inputs,
                allow_unused=True,
                retain_graph=True,
            )
            fusion_grads = torch.autograd.grad(
                fusion_output.sum(),
                local_fusion_inputs,
                allow_unused=True,
                retain_graph=True,
            )
            grads_half = [t.half() for t in grads]
            self.assertEqual(grads_half, fusion_grads)
def test_checks_cat_inputs(self):
    """Check that ``cat`` inputs of different sizes are not broadcast."""
    # single fusion node causes error
    with set_fusion_group_inlining(True):
        for device in self.devices:
            # We shouldn't treat cat nodes as broadcasting. All their inputs
            # need to be checked for having the same map size, before we can
            # run the kernel.
            def f(x, y):
                return torch.cat([x + 2 * x + x**2, y + 4 * y + y**3], dim=0)

            # NOTE: y is broadcastable to x, but output of f(x, y) should have
            # shape 3x4, and not 4x4.
            two_rows = torch.randn(2, 4, dtype=torch.float, device=device)
            one_row = torch.randn(1, 4, dtype=torch.float, device=device)

            scripted = self.checkScript(f, (two_rows, one_row))
            self.assertEqual(scripted(two_rows, one_row).shape, (3, 4))
            self.assertAllFused(scripted.graph_for(two_rows, one_row))
def test_chunk(self):
    """Fuse a three-way chunk along dim 1 followed by ``a * b + c``."""
    if self.dynamic_shapes:
        self.skipTest("TODO: chunk dynamic shapes")

    for device in self.devices:

        def fn(x):
            a, b, c = x.chunk(3, 1)
            return a * b + c

        sample = torch.randn(10, 6, dtype=torch.float, device=device)
        self.checkScript(fn, [sample])
        self.assertLastGraphAllFused()
def test_chunk_correctness(self):
    """Check four-way chunking along dims 0, 1 and 2 on several tensors."""
    if self.dynamic_shapes:
        self.skipTest("TODO: chunk dynamic shapes")

    for device in self.devices:

        def chunk_4_0(x):
            x0, x1, x2, x3 = x.chunk(4, 0)
            return x0 + x1 + x2 + x3

        def chunk_4_1(x):
            x0, x1, x2, x3 = x.chunk(4, 1)
            return x0 + x1 + x2 + x3

        def chunk_4_last(x):
            x0, x1, x2, x3 = x.chunk(4, 2)
            return x0 + x1 + x2 + x3

        chunkers = (chunk_4_0, chunk_4_1, chunk_4_last)
        # splitSize = 1
        unit_split = torch.randn(4, 4, 4, dtype=torch.float, device=device)
        # contiguous case
        contiguous = torch.randn(12, 8, 16, dtype=torch.float, device=device)
        # non-contiguous case
        strided = torch.randn(12, 8, 16, dtype=torch.float, device=device).transpose(
            1, 2
        )

        # Tensors vary slowest, matching a tensor-outer / function-inner nest.
        for tensor, chunker in product((unit_split, contiguous, strided), chunkers):
            self.checkScript(chunker, [tensor])
            self.assertLastGraphAllFused()
def test_chunk_distributes(self):
    """Check that a chunk over ``x + y`` ends up inside one fusion group.

    The traced graph must contain a fusion group and exactly one
    ``ConstantChunk`` occurrence in its printed form.
    """
    # The skip guard was previously written twice in a row; one is enough.
    if self.dynamic_shapes:
        self.skipTest("TODO: chunk dynamic shapes")

    for device in self.devices:

        def f(x, y):
            z1, z2 = (x + y).chunk(2, dim=1)
            return z1 * z2

        x = torch.randn(4, 4, dtype=torch.float, device=device)
        y = torch.randn(4, 4, dtype=torch.float, device=device)

        ge = self.checkTrace(f, (x, y))
        graph = ge.graph_for(x, y)

        # XXX: The old fuser does broadcast_tensors but the new fuser doesn't.
        # FileCheck().check("broadcast_tensors").check('with ' + FUSION_GROUP + '_') \
        #     .check_count('ConstantChunk', 2, exactly=True).run(str(graph))
        FileCheck().check("with " + FUSION_GROUP + "_").check_count(
            "ConstantChunk", 1, exactly=True
        ).run(str(graph))
def test_chunk_motion_deduplicates_inputs(self):
    """Fuse chunks of ``x * x`` and ``x * x * x`` built from a single input."""
    if self.dynamic_shapes:
        self.skipTest("TODO: chunk dynamic shapes")

    for device in self.devices:

        def func1(x):
            z = x * x
            z0, z1 = z.chunk(2)
            return z0 * z1

        def func2(x):
            z = x * x * x
            z0, z1 = z.chunk(2)
            return z0 * z1

        sample = torch.tensor([1.1, 1.2], device=device, dtype=torch.float)
        for candidate in (func1, func2):
            self.checkScript(candidate, [sample])
            self.assertLastGraphAllFused()
def test_chunk_multiple(self):
    """Fuse several chunks taken along different dims of different inputs."""
    if self.dynamic_shapes:
        self.skipTest("TODO: chunk dynamic shapes")

    for device in self.devices:
        # The arguments are intentionally used out of order as a test to see
        # if the fusion compiler adds extra args in the correct order
        def fn(s, x, y, z):
            z1, z2 = z.chunk(2, 2)
            x1, x2, x3 = x.chunk(3, 1)
            y1, y2 = y.chunk(2, 0)
            return s + x1 + x2 + x3 + y1 + y2 + z1 + z2

        shapes = ((5, 2, 3), (5, 6, 3), (10, 2, 3), (5, 2, 6))
        inputs = [
            torch.randn(*shape, dtype=torch.float, device=device) for shape in shapes
        ]

        scripted = self.checkScript(fn, inputs)
        self.assertAllFused(scripted.graph_for(*inputs))
def test_minmax(self):
    """Fuse elementwise ``torch.max``/``torch.min``, including NaN operands.

    Every combination of function, input pair and device is scripted once and
    its optimized graph must be fully fused.
    """

    def tmax(a, b):
        return torch.max(2 * a, b)

    def tmin(a, b):
        return torch.min(2 * a, b)

    a = torch.randn(4, 4, dtype=torch.float)
    b = torch.randn(4, 4, dtype=torch.float)
    nan = torch.tensor(float("nan"), dtype=torch.float)

    # `product` already enumerates every device. The enclosing
    # `for device in self.devices` loop that used to wrap this body was
    # shadowed by the `device` bound here and only repeated the same matrix
    # once per device, so it has been removed.
    for f, inputs, device in product(
        (tmax, tmin), ([a, b], [a, nan], [b, nan]), self.devices
    ):
        device_inputs = [t.to(device) for t in inputs]
        s = self.checkScript(f, device_inputs)
        self.assertAllFused(s.graph_for(*device_inputs))
def test_clamp(self):
    """Fuse ``torch.clamp`` variants in both the forward and backward graphs.

    Each clamp variant is scripted with a regular and a NaN second operand;
    the forward graph and, after a warmed-up backward pass, the backward graph
    must be fully fused apart from the listed exceptions.
    """
    for device in self.devices:

        def func2(a, b):
            return torch.clamp(a + b, min=0, max=2)

        def funcInf(a, b):
            return torch.clamp(a + b, min=0, max=float("inf"))

        def funcNegInf(a, b):
            return torch.clamp(a + b, min=float("-inf"), max=0)

        def funcOptMin(a, b):
            return torch.clamp(a + b, max=2)

        def funcOptMax(a, b):
            return torch.clamp(a + b, min=0)

        a = torch.randn(4, 4, dtype=torch.float, device=device, requires_grad=True)
        b = torch.randn(4, 4, dtype=torch.float, device=device)
        nan = torch.tensor(float("nan"), dtype=torch.float, device=device)

        funcs = (func2, funcInf, funcNegInf, funcOptMin, funcOptMax)
        for f, inputs in product(funcs, [[a, b], [a, nan]]):
            inp1, inp2 = inputs
            s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING)
            # Forward graph: size bookkeeping nodes are tolerated.
            self.assertAllFused(
                s.graph_for(inp1, inp2),
                except_for={"aten::size", "aten::_size_if_not_equal"},
            )
            c = s(inp1, inp2)
            with enable_profiling_mode_for_profiling_tests():
                warmup_backward(c.sum())
            # Backward graph: also tolerate the autograd-check node kinds.
            graph = backward_graph(s)
            self.assertAllFused(
                graph,
                except_for={"aten::Float", "aten::_grad_sum_to_size"}.union(
                    autograd_check_set
                ),
            )
def test_clamp_double(self):
    """Fuse a clamp on a double tensor whose bounds come from a float arg."""
    for device in self.devices:

        def clamp_double(x, eta: float):
            return 1 - x.clamp(eta, 1 - eta)

        x = torch.tensor([1.0, 1.0], dtype=torch.double, device=device)
        eta = 1e-9
        check_options = dict(
            profiling=ProfilingMode.PROFILING,
            atol=1e-10,
            rtol=1e-5,
        )
        scripted = self.checkScript(clamp_double, (x, eta), **check_options)
        self.assertAllFused(scripted.graph_for(x, eta), except_for={"aten::sub"})
def test_clamp_int(self):
    """Fuse a clamp whose upper bound is the int argument ``1 << 32``."""
    for device in self.devices:

        def clamp_int(x, eta: int):
            return x.clamp(0, eta)

        x = torch.tensor([1, 1], device=device)
        eta = 1 << 32

        scripted = self.checkScript(
            clamp_int, (x, eta), profiling=ProfilingMode.PROFILING
        )
        self.assertAllFused(scripted.graph_for(x, eta))
def test_add_bool(self):
    """Fuse the sum of three bool tensors for several sizes on each device."""
    sizes = [(1,), (2,), (4, 4)]
    for device, size in product(self.devices, sizes):

        def f(x, y, z):
            return x + y + z

        x, y, z = (
            torch.randint(0, 2, size, dtype=torch.bool, device=device)
            for _ in range(3)
        )
        traced = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
        self.assertAllFused(traced.graph_for(x, y, z))
def test_mul_bool(self):
    """Fuse the product of three 4x4 bool tensors on each device."""
    for device in self.devices:

        def f(x, y, z):
            return x * y * z

        x, y, z = (
            torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
            for _ in range(3)
        )
        traced = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
        self.assertAllFused(traced.graph_for(x, y, z))
def test_div_bool(self):
    """Fuse ``(x + y) / z`` on bool tensors, with ``z`` filled with ones."""
    for device in self.devices:

        def f(x, y, z):
            return (x + y) / z

        x, y = (
            torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
            for _ in range(2)
        )
        z = torch.ones_like(x, dtype=torch.bool, device=device)
        traced = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
        self.assertAllFused(traced.graph_for(x, y, z))
def test_bitwise_ops(self):
    """Trace and fuse chained bitwise operators over the integer dtypes.

    Combinations that fail in eager mode are skipped; a failure while tracing
    or checking is re-raised as a RuntimeError naming the dtype, operator and
    device.
    """

    def apply(fn):
        # Apply the binary operator twice so three inputs are combined.
        return lambda x, y, z: fn(fn(x, y), z)

    binary_ops = [
        operator.__and__,
        operator.__or__,
        operator.__xor__,
        operator.__lshift__,
        operator.__rshift__,
    ]
    devices = self.devices
    for dtype, op, device in product(self.int_dtypes, binary_ops, devices):
        try:
            x = self.data_for(dtype, device)
            y = self.data_for(dtype, device)
            z = self.data_for(dtype, device)
            fn = apply(op)
            ref = fn(x, y, z)
        except Exception:
            # If eager mode doesn't support a dtype/op/device combo,
            # neither does the fuser.  Catch everything to avoid needing to
            # guess what errors might be thrown by eager.
            continue
        try:
            t = torch.jit.trace(fn, (x, y, z))
            self.assertEqual(ref, t(x, y, z))
            self.assertAllFused(t.graph_for(x, y, z))
        except Exception as e:
            # Re-raise with the failing combination attached for context.
            raise RuntimeError(
                " ".join(["Failed:", str(dtype), op.__name__, device])
            ) from e
def test_minmax_int_ops(self):
    """Trace and fuse chained ``torch.min``/``torch.max`` over integer dtypes.

    Combinations that fail in eager mode are skipped; a failure while tracing
    or checking is re-raised as a RuntimeError naming the dtype, operator and
    device.
    """

    def apply(fn):
        # Apply the binary operator twice so three inputs are combined.
        return lambda x, y, z: fn(fn(x, y), z)

    binary_ops = [torch.min, torch.max]
    devices = self.devices
    for dtype, op, device in product(self.int_dtypes, binary_ops, devices):
        try:
            x = self.data_for(dtype, device)
            y = self.data_for(dtype, device)
            z = self.data_for(dtype, device)
            fn = apply(op)
            ref = fn(x, y, z)
        except Exception:
            # If eager mode doesn't support a dtype/op/device combo,
            # neither does the fuser.  Catch everything to avoid needing to
            # guess what errors might be thrown by eager.
            continue
        try:
            t = torch.jit.trace(fn, (x, y, z))
            self.assertEqual(ref, t(x, y, z))
            self.assertAllFused(t.graph_for(x, y, z))
        except Exception as e:
            # Re-raise with the failing combination attached for context.
            raise RuntimeError(
                " ".join(["Failed:", str(dtype), op.__name__, device])
            ) from e
def test_comparison_eq_ne(self):
    """Fuse masks built from ``==`` and ``!=`` comparisons against zero."""
    for device in self.devices:

        def f(x, y):
            mask = (x == 0).type_as(x)
            z = x * mask + y
            mask = (x != 0).type_as(x)
            z = z * mask + y
            return z

        x, y = (
            torch.randn(4, 4, dtype=torch.float, device=device) for _ in range(2)
        )
        traced = self.checkTrace(f, (x, y))
        self.assertAllFused(traced.graph_for(x, y))
@staticmethod
def fn_test_comparison_gt_lt(x, y):
    # Builds masks from `x > 0` and `x < 0` (cast to x's type), applying each
    # as `value * mask + y`. Used as a traced function by the tests in this
    # class.
    mask = (x > 0).type_as(x)
    z = x * mask + y
    mask = (x < 0).type_as(x)
    z = z * mask + y
    return z
def test_comparison_gt_lt(self):
    """Trace ``fn_test_comparison_gt_lt`` and assert it is fully fused."""
    for device in self.devices:
        x, y = (
            torch.randn(4, 4, dtype=torch.float, device=device) for _ in range(2)
        )
        traced = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y))
        self.assertAllFused(traced.graph_for(x, y))
def test_comparison_ge_le(self):
    """Fuse ``>=``/``<=`` masks, without and then with inputs requiring grad."""
    for device in self.devices:

        def f(x, y):
            mask = (x >= 0).type_as(x)
            z = x * mask + y
            mask = (x <= 0).type_as(x)
            z = z * mask + y
            return z

        x, y = (
            torch.randn(4, 4, dtype=torch.float, device=device) for _ in range(2)
        )
        traced = self.checkTrace(f, (x, y))
        self.assertAllFused(traced.graph_for(x, y))

        for tensor in (x, y):
            tensor.requires_grad_(True)
        size_node_kinds = (
            "aten::size",
            "prim::BroadcastSizes",
            "aten::_size_if_not_equal",
        )
        self.assertAllFused(traced.graph_for(x, y), except_for=size_node_kinds)
def test_addcmul(self):
    """Check that ``add`` and ``addcmul`` land in a single fusion group."""
    for device in self.devices:
        t = torch.randn(1, 4, dtype=torch.float, device=device)
        t1 = torch.randn(4, 1, dtype=torch.float, device=device)
        t2 = torch.randn(1, 4, dtype=torch.float, device=device)

        # `t1` is accepted but not used in the body; the trace check below is
        # run with allow_unused=True accordingly.
        def foo(t, t1, t2):
            return t.addcmul(t + 1, t2, value=0.1)

        ge = self.checkTrace(foo, (t, t1, t2), allow_unused=True)
        graph = ge.graph_for(t, t1, t2)
        fusion_groups = self.findFusionGroups(graph)
        self.assertEqual(len(fusion_groups), 1)
        # Both ops must appear, in this order, inside the one fused subgraph.
        FileCheck().check("aten::add(").check("aten::addcmul(").run(
            str(fusion_groups[0])
        )
# TODO: We leak CUDA memory here because the traced graph holds onto a
# constant-ified tensor. Since the Python-global CompilationUnit is alive
# until the end of the process, the memory is effectively leaked.
# Removed `_cuda` suffix from this test which disables leak-checking.
# If this is a real problem, we'll need to revisit Torchscript Function
# lifetimes in Python.
def test_lerp(self):
    """Fuse ``torch.lerp`` with a scalar weight; the tensor case is disabled."""
    for device in self.devices:
        start = torch.randn(4, 1, dtype=torch.float, device=device)
        end = torch.randn(1, 4, dtype=torch.float, device=device)
        weight = torch.tensor(0.5, dtype=torch.float, device=device)

        # scalar weight overload
        def foo_weight_scalar(start, end):
            return torch.lerp(start + 1, end, 0.5)

        # tensor weight overload
        # (captures `weight` from the enclosing scope; currently only
        # referenced by the commented-out check below)
        def foo_weight_tensor(start, end):
            return torch.lerp(start + 1, end, weight)

        ge_weight_scalar = self.checkTrace(foo_weight_scalar, (start, end))
        graph = ge_weight_scalar.graph_for(start, end)
        self.assertAllFused(graph)

        # TODO: uncomment when TE enables support for scalar tensors
        # ge_weight_tensor = self.checkTrace(foo_weight_tensor, (start, end))
        # graph = ge_weight_tensor.graph_for(start, end)
        # self.assertAllFused(graph)
def test_concat(self):
    """Fuse a ``torch.cat`` of a sum and a product of the same two inputs."""
    # disabling concat causes error with single concat node
    with set_fusion_group_inlining(True):
        for device in self.devices:
            hidden = torch.randn(3, 20, dtype=torch.float, device=device)
            cell = torch.randn(3, 20, dtype=torch.float, device=device)

            def foo(hx, cx):
                return torch.cat((hx + cx, hx * cx))

            traced = self.checkTrace(foo, (hidden, cell))
            graph = traced.graph_for(hidden, cell)
            self.assertAllFused(graph)
            # XXX: TE fuser can handle concats in a fusion group.
            # FileCheck().check("FusedConcat").check_next("return").run(str(graph))
def test_remove_output_used_only_in_size(self):
    """Check the guarded fusion ``prim::If`` has exactly one output."""
    for device in self.devices:

        def test_fuse(a, b):
            c = a + b
            d = c + b
            return d

        scripted_f = torch.jit.script(test_fuse)
        x, y = (torch.ones(1, requires_grad=True, device=device) for _ in range(2))
        warmup_forward(scripted_f, x, y, profiling_count=3)

        # Keep the outer graph bound while its nodes are being inspected.
        graph = scripted_f.graph_for(x, y)
        diff_nodes = graph.findAllNodes("prim::DifferentiableGraph")
        self.assertEqual(len(diff_nodes), 1)

        subgraph = diff_nodes[0].g("Subgraph")
        if_nodes = [node for node in subgraph.nodes() if node.kind() == "prim::If"]
        self.assertEqual(len(if_nodes), 1)

        # the if node and the fusion group inside it should only have one output
        if_outputs = list(if_nodes[0].outputs())
        self.assertEqual(len(if_outputs), 1)
def test_concat_invariant(self):
    """Fuse a ``cat`` whose result feeds a later add (``aten::add`` allowed)."""
    for device in self.devices:
        # Invariant: the output of prim::FusedConcat may
        # not be an input to any node inside the FusionGroup.
        def fn(x, y, z):
            x1 = x + y
            y1 = x - y
            w = torch.cat([x1, y1])
            return w + z

        shapes = ((2, 2), (2, 2), (4, 2))
        x, y, z = (
            torch.randn(*shape, dtype=torch.float, device=device) for shape in shapes
        )
        traced = self.checkTrace(fn, (x, y, z))
        graph = traced.graph_for(x, y, z)
        self.assertAllFused(graph, except_for={"aten::add"})
        # XXX: TE fuser can handle concats inside a fusion group.
        # FileCheck().check("FusedConcat").check_next("return").run(str(graph))
@staticmethod
def fn_test_exp(x, y):
    # Returns exp(x + 0.5 * y). Used as a traced function by the tests in
    # this class.
    return (x + 0.5 * y).exp()
def test_exp(self):
    """Trace ``fn_test_exp`` on each device and assert it is fully fused."""
    for device in self.devices:
        x, y = (
            torch.randn(4, 4, dtype=torch.float, device=device) for _ in range(2)
        )
        traced = self.checkTrace(self.fn_test_exp, (x, y))
        self.assertAllFused(traced.graph_for(x, y))
def test_threshold(self):
    """Fuse ``torch.threshold`` followed by three additions of the input."""
    for device in self.devices:

        def f(x):
            return torch.threshold(x, 0, -10) + x + x + x

        sample = torch.tensor([-1, -0.5, 0, 1, 2, 3], device=device)
        scripted = self.checkScript(f, (sample,))
        self.assertAllFused(scripted.graph_for(sample))
def test_scalar_arg(self):
    """Fuse a function taking a scalar argument, without and with grad."""
    for device in self.devices:

        def fn_test_scalar_arg(x: torch.Tensor, p: float) -> torch.Tensor:
            return p * (x * x + x)

        x = torch.randn(4, 4, dtype=torch.float, device=device)
        p = 3
        scripted = self.checkScript(fn_test_scalar_arg, (x, p))
        self.assertAllFused(scripted.graph_for(x, p))

        x.requires_grad_(True)

        # use another function otherwise we will bailout
        # and won't be able to do fused checks
        def fn_test_scalar_arg_requires_grad(
            x: torch.Tensor, p: float
        ) -> torch.Tensor:
            return p * (x * x + x)

        scripted = torch.jit.script(fn_test_scalar_arg_requires_grad)
        # Run the scripted function three times before inspecting its graph.
        out = scripted(x, p)
        out = scripted(x, p)
        out = scripted(x, p)
        self.assertAllFused(
            scripted.graph_for(x, p),
            except_for=(
                "aten::size",
                "prim::BroadcastSizes",
                "aten::_size_if_not_equal",
            ),
        )
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
def test_fusion_reuse_multi_gpu(self):
    """Run one scripted function on CPU, ``cuda:0`` and ``cuda:1`` inputs."""

    def fn(x, y):
        return x * y * x * y

    inputs_cpu = [torch.randn(4, 4, dtype=torch.float) for _ in range(2)]
    inputs_cuda0 = [tensor.cuda(0) for tensor in inputs_cpu]
    inputs_cuda1 = [tensor.cuda(1) for tensor in inputs_cpu]

    # Should not crash; these should compile different kernels.
    scripted = self.checkScript(fn, inputs_cpu)
    self.assertAllFused(scripted.graph_for(*inputs_cpu))
    for device_inputs in (inputs_cuda0, inputs_cuda1):
        scripted(*device_inputs)
# TODO: we're currently not checking 'device' in the type info when pulling
# nodes into a fusion group. We should fix that and re-enable this test.
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
def test_kernel_cache_multi_gpu(self):
    """Check that three same-shaped fusions on three devices are all found.

    The kernel-spec cache size is sampled before and after, but the assertion
    on its growth is currently commented out.
    """

    def not_fusible(x):
        return x

    def fn(x, y, z):
        x_out = x * x * x * x * x  # fusion: lambda x. x * x * x * x * x
        y_out = y * y * y * y * y
        z_out = z * z * z * z * z
        return not_fusible(x_out), not_fusible(y_out), not_fusible(z_out)

    # One input each on the CPU, cuda:0 and cuda:1.
    inputs = [
        torch.randn(4, 4, dtype=torch.float),
        torch.randn(4, 4, dtype=torch.float, device="cuda:0"),
        torch.randn(4, 4, dtype=torch.float, device="cuda:1"),
    ]

    prev_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()

    # There are 3 FusionGroups. Because they have the same graph, they
    # should reuse the same KernelSpec in the KernelSpec cache.
    ge = self.checkScript(fn, inputs)
    self.assertGraphContainsExactly(ge.graph_for(*inputs), FUSION_GROUP, 3, True)
    new_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
    # XXX: This assumes that the same kernel isn't already used by another test
    # FIXME: Use the TE fuser's way of querying the cache.
    # self.assertEqual(new_cache_size - prev_cache_size, 1)
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
def test_nonzero_device_cuda(self):
    """Fuse a sigmoid/tanh chain on tensors that live on ``cuda:1``."""
    device = "cuda:1"
    x, y = (
        torch.tensor([value], dtype=torch.float, device=device)
        for value in (0.4, 0.7)
    )

    def doit(x, y):
        return torch.sigmoid(torch.tanh(x * (x + y) + x))

    traced = self.checkTrace(doit, (x, y))
    self.assertAllFused(traced.graph_for(x, y))
def test_lstm(self):
    """Script ``LSTMCellS`` on training inputs and assert it is fully fused.

    ``prim::TupleConstruct`` is tolerated because the cell returns a tuple.
    """
    for device in self.devices:
        inputs = get_lstm_inputs(device, training=True)
        module = self.checkScript(LSTMCellS, inputs)
        # Unpack the inputs so graph_for receives the same positional
        # arguments the cell was scripted with, as test_milstm does; passing
        # the tuple itself handed it a single argument.
        self.assertAllFused(
            module.graph_for(*inputs), except_for={"prim::TupleConstruct"}
        )
def test_lstm_concat(self):
    """Trace ``LSTMCellC`` and assert it is fused apart from the listed kinds."""
    # single fusion node causes error
    with set_fusion_group_inlining(True):
        for device in self.devices:
            inputs = get_lstm_inputs(device)
            ge = self.checkTrace(LSTMCellC, inputs)
            graph = ge.graph_for(*inputs)
            # Node kinds tolerated outside the fusion group.
            except_nodes = {"prim::TupleConstruct", "aten::linear"}
            # TODO... Chunk
            if self.dynamic_shapes:
                except_nodes = except_nodes.union(
                    {"aten::add", "prim::ConstantChunk"}
                )
            self.assertAllFused(ge.graph_for(*inputs), except_for=except_nodes)
            # XXX: TE fuser can handle concats inside a fusion group.
            # FileCheck().check("FusedConcat").check_next("return").run(str(graph))
def test_lstm_gates_permutations(self):
    """Check every ordering of the four LSTM gate terms yields the same fusion.

    For each permutation a ``cell`` function is generated from a source
    template, compiled both with ``exec`` (eager reference) and with
    ``torch.jit.CompilationUnit``; the two results must match and the compiled
    graph must contain the expected number of fusion groups.
    """
    for device in self.devices:
        # lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh.
        # Test that any permutation of this will still result in one FusionGroup.
        choices = ["x.mm(w_ih.t())", "hx.mm(w_hh.t())", "b_ih", "b_hh"]
        # The body of `cell` must be indented relative to its `def` line:
        # `dedent` only strips the common margin, and without the relative
        # indentation neither `exec` nor `CompilationUnit` can parse the source.
        template = dedent(
            """
        def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
            gates = {} + {} + {} + {}
            ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
            return ingate * forgetgate * cellgate * outgate
        """
        )
        for permutation in permutations(choices, len(choices)):
            code = template.format(*permutation)
            scope = {}
            # `code` is built only from the fixed template and `choices` above.
            exec(code, globals(), scope)
            cu = torch.jit.CompilationUnit(code)
            fusion_group_len = 2 if self.dynamic_shapes else 1
            inputs = get_lstm_inputs(device, training=False)
            self.assertEqual(cu.cell(*inputs), scope["cell"](*inputs))
            forward_graph = cu.cell.graph_for(*inputs)
            self.assertGraphContainsExactly(
                forward_graph, FUSION_GROUP, fusion_group_len
            )
# TODO: Fuser doesn't work at all when inputs require grad. Fix that
def test_lstm_traced(self):
    """Trace ``LSTMCellF`` and inspect the contents of its fusion group."""
    for device in self.devices:
        inputs = get_lstm_inputs(device)
        traced = self.checkTrace(LSTMCellF, inputs)
        graph = traced.graph_for(*inputs)
        fusion_groups = self.findFusionGroups(graph)
        # TODO: chunk
        if self.dynamic_shapes:
            expected_groups, inspected = 2, 1
        else:
            expected_groups, inspected = 1, 0
        self.assertEqual(len(fusion_groups), expected_groups)

        checker = FileCheck()
        if not self.dynamic_shapes:
            checker.check("Chunk")
        checker.check("aten::sigmoid").check("aten::tanh")
        checker.run(str(fusion_groups[inspected]))
def test_milstm(self):
    """Script ``MiLSTMCell``, check its fusion groups, then run a backward."""
    if self.dynamic_shapes:
        self.skipTest("don't run conv with dynamic shapes")

    for device in self.devices:
        inputs = get_milstm_inputs(device, training=True)
        module = self.checkScript(MiLSTMCell, inputs)
        forward_graph = module.graph_for(*inputs)
        # TODO: chunk
        # NOTE(review): the dynamic-shapes case is skipped above, so this
        # always evaluates to 1 here.
        fusion_group_len = 2 if self.dynamic_shapes else 1
        self.assertGraphContainsExactly(
            forward_graph, FUSION_GROUP, fusion_group_len, consider_subgraphs=True
        )
        FileCheck().check("DifferentiableGraph").check("TupleConstruct").check_next(
            "return"
        ).check(FUSION_GROUP).run(str(forward_graph))
        hy, cy = module(*inputs)
        warmup_backward((hy + cy).sum())
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@unittest.skip("rand_like is not supported yet")
def test_rand_cuda(self):
    """Check ``rand_like`` inside a script method (unconditionally skipped).

    Two calls must produce different outputs whose values all lie in
    ``[0, 1)`` (the input is all zeros), and the graph must be fully fused.
    """

    class M(torch.jit.ScriptModule):
        __constants__ = ["d"]

        def __init__(self) -> None:
            super().__init__()
            self.d = torch.device("cuda")

        @torch.jit.script_method
        def create(self, x):
            return x * x + x + torch.rand_like(x)

    x = torch.zeros([3, 4, 5], dtype=torch.float, device="cuda")
    m = M()
    out1 = m.create(x)
    out2 = m.create(x)
    self.assertNotEqual(out1, out2)
    self.assertTrue(torch.all(out1 >= 0))
    self.assertTrue(torch.all(out1 < 1))
    self.assertTrue(torch.all(out2 >= 0))
    self.assertTrue(torch.all(out2 < 1))
    self.assertAllFused(m.create.graph_for(x))
@staticmethod
def fn_test_relu(x, y):
    # Returns relu(x + 0.5 * y). Used as a traced function by the tests in
    # this class.
    return F.relu(x + 0.5 * y)
def test_relu(self):
    """Trace ``fn_test_relu`` on each device and assert it is fully fused."""
    for device in self.devices:
        x, y = (
            torch.randn(4, 4, dtype=torch.float, device=device) for _ in range(2)
        )
        traced = self.checkTrace(self.fn_test_relu, (x, y))
        self.assertAllFused(traced.graph_for(x, y))
def test_erf(self):
    """Fuse ``erf``/``erfc`` on non-CPU devices, without and with grad."""
    # only enabled on gpu
    for device in (d for d in self.devices if d != "cpu"):

        def fn_test_erf(x):
            return F.relu(torch.erf(x) - torch.erfc(x))

        x = torch.randn(4, 4, dtype=torch.float, device=device)
        scripted = self.checkScript(
            fn_test_erf, (x,), profiling=ProfilingMode.PROFILING
        )
        self.assertAllFused(scripted.graph_for(x))

        x.requires_grad_(True)
        scripted = self.checkScript(
            fn_test_erf, (x,), profiling=ProfilingMode.PROFILING
        )
        size_node_kinds = (
            "aten::size",
            "prim::BroadcastSizes",
            "aten::_size_if_not_equal",
        )
        self.assertAllFused(scripted.graph_for(x), except_for=size_node_kinds)
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@unittest.skip("rand_like is not supported yet")
def test_rand_broadcast_cuda(self):
    """Check fusion and broadcasting of ``rand_like`` (unconditionally skipped)."""

    def fn_test_rand(x, y):
        r = torch.rand_like(y)
        return r * x + x

    # If using profiling, a different function is needed to test different
    # shapes, or we'll use a cached script.
    def fn_test_rand2(x, y):
        r = torch.rand_like(y)
        return r * x * x

    x = torch.randn(4, 4, dtype=torch.float, device="cuda")
    y = torch.randn(4, 4, dtype=torch.float, device="cuda")
    script_f = torch.jit.script(fn_test_rand)
    warmup_forward(script_f, x, y)
    out = script_f(x, y)
    self.assertAllFused(script_f.graph_for(x, y))
    x.requires_grad_(True)
    out = script_f(x, y)
    self.assertAllFused(
        script_f.graph_for(x, y),
        except_for=(
            "aten::size",
            "prim::BroadcastSizes",
            "aten::_size_if_not_equal",
        ),
    )

    # test that broadcasting random produces correct results
    x = torch.ones(4, 4, dtype=torch.float, device="cuda")
    y = torch.ones(4, dtype=torch.float, device="cuda")
    script_f = torch.jit.script(fn_test_rand2)
    warmup_forward(script_f, x, y)
    out = script_f(x, y)
    # Every row must equal the first row: the length-4 random vector is
    # broadcast across the rows of the all-ones `x`.
    self.assertEqual(out[0, :] + torch.zeros(4, 4, device="cuda"), out)
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@unittest.skip("rand_like is not supported yet")
def test_rand_diamond(self):
    # Currently skipped unconditionally. The same random tensor is added to x
    # and subtracted from y, so the scripted result must equal x + y.
    def fn_test_diamond(x, y):
        r = torch.rand_like(y)
        a = x + r
        b = y - r
        return a + b

    lhs = torch.randn(4, 4, dtype=torch.float, device="cuda")
    rhs = torch.randn(4, 4, dtype=torch.float, device="cuda")
    scripted = torch.jit.script(fn_test_diamond)
    warmup_forward(scripted, lhs, rhs)
    result = scripted(lhs, rhs)
    self.assertEqual(result, lhs + rhs)
def test_scalar(self):
    # Script `2 * x + y` on two 0-dim CPU float tensors and run assertAllFused
    # on the resulting graph.
    def fn(x, y):
        return 2 * x + y

    a = torch.tensor(0.1, dtype=torch.float, device="cpu")
    b = torch.tensor(1, dtype=torch.float, device="cpu")
    scripted = self.checkScript(fn, (a, b))
    self.assertAllFused(scripted.graph_for(a, b))
def test_inlined_optimized_graph(self):
    # Run one scripted function three times for each of three input shapes,
    # then inspect the last optimized graph: exactly one prim::If followed by a
    # prim::TensorExpr before inlining, and three prim::If / prim::TensorExpr
    # pairs in order after inlining.
    @torch.jit.script
    def foo(x):
        return torch.relu(x + x)

    for shape in ([4, 4], [10], [2, 2, 2]):
        for _ in range(3):
            foo(torch.rand(shape))

    graph = torch.jit.last_executed_optimized_graph()

    before_inline = FileCheck().check_count("prim::If", 1, exactly=True)
    before_inline.check("prim::TensorExpr").run(graph)

    torch._C._jit_pass_inline(graph)
    after_inline = FileCheck()
    for _ in range(3):
        after_inline.check("prim::If").check("prim::TensorExpr")
    after_inline.run(graph)
def test_small_constant(self):
    # Trace an expression mixing very small (1e-8, 5e-9) and large (1e8)
    # constants on each device and run assertAllFused on the traced graph.
    for dev in self.devices:

        def fn_test_small_constant(x, y):
            return (1e-8 * x + 5e-9 * y) * 1e8

        lhs = torch.randn(4, 4, dtype=torch.float, device=dev)
        rhs = torch.randn(4, 4, dtype=torch.float, device=dev)

        traced = self.checkTrace(fn_test_small_constant, (lhs, rhs))
        self.assertAllFused(traced.graph_for(lhs, rhs))
# Currently we don't pull constants into fusion groups, because in some
# cases it could remove the constant from the original graph and now our
# fusion group needs to return that constant for its other users.
# Instead of never pulling constants into the fusion group, we should just
# be more careful about how we rewrite its users.
# TODO: fix that and reenable the test.
def test_tensor_scalar_ops(self):
    # Two scripted functions per device: one adds a constant float to a tensor
    # and must yield exactly one fusion group containing add then mul; the
    # other adds int(z) for a tensor z, is re-run with a different z, and must
    # contain exactly one fusion group (counting subgraphs).
    for dev in self.devices:

        def should_fuse(x):
            z = 3.0
            y = x + z
            return x * y

        def should_fuse_scalar(x, z):
            y = x + int(z)
            return x * y

        const_inputs = [torch.randn(2, 2, dtype=torch.float, device=dev)]
        scripted = self.checkScript(should_fuse, const_inputs)
        groups = self.findFusionGroups(scripted.graph_for(*const_inputs))
        self.assertEqual(len(groups), 1)
        FileCheck().check("aten::add").check("aten::mul").run(str(groups[0]))

        warmup_inputs = [
            torch.randn(2, 2, dtype=torch.float, device=dev),
            torch.tensor(3.0, dtype=torch.float, device=dev),
        ]
        scripted = self.checkScript(should_fuse_scalar, warmup_inputs)

        # Check that the fused graph computes correct results when the scalar
        # input changes.
        changed_inputs = [
            torch.randn(2, 2, dtype=torch.float, device=dev),
            torch.tensor(7.0, dtype=torch.float, device=dev),
        ]
        self.assertEqual(
            scripted(*changed_inputs), should_fuse_scalar(*changed_inputs)
        )

        # The TE fuser supports fusion of non-constant scalars
        self.assertGraphContainsExactly(
            scripted.graph_for(*changed_inputs),
            FUSION_GROUP,
            1,
            consider_subgraphs=True,
        )
def test_where_and_typing(self):
    # Script a function returning both a comparison mask and a torch.where
    # result on double inputs; everything except the tuple construction is
    # passed through assertAllFused.
    for dev in self.devices:

        def f(x, y):
            mask = x > y
            res = torch.where(mask, x, y)
            return mask, res

        lhs = torch.randn(4, 4, dtype=torch.double, device=dev)
        rhs = torch.randn(4, 4, dtype=torch.double, device=dev)

        scripted = self.checkScript(f, (lhs, rhs))
        graph = scripted.graph_for(lhs, rhs)
        self.assertAllFused(graph, except_for={"prim::TupleConstruct"})
def test_disabled(self):
    # With CPU fusion overridden to off, scripting `a**2 + a` on a CPU tensor
    # must produce a graph with zero fusion groups.
    def fn(a):
        return a**2 + a

    old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu()
    torch._C._jit_override_can_fuse_on_cpu(False)
    try:
        x = torch.randn(4, dtype=torch.float, device="cpu")
        s = self.checkScript(fn, (x,))
        g = s.graph_for(x)
        self.assertEqual(len(self.findFusionGroups(g)), 0)
    finally:
        # Restore the previous setting even when a check above fails, so the
        # override does not leak into whatever runs after this test.
        torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuser_state)
def data_for(self, dtype, device="cuda", size=None):
    """Build a sample tensor of ``dtype`` on ``device``.

    Without ``size`` the base values are ``torch.arange(1, 3)`` as float; with
    ``size`` they are ``torch.rand(*size)``. The base tensor is then converted:
    ``base > 2`` for ``torch.bool``, ``torch.quantize_per_tensor`` (scale 0.1,
    zero point 1) for the quantized dtypes, and ``base.to(dtype)`` otherwise.
    """
    if size is not None:
        base = torch.rand(*size, device=device)
    else:
        base = torch.arange(1, 3, dtype=torch.float, device=device)

    if dtype == torch.bool:
        return base > 2
    if dtype in [torch.qint8, torch.quint8, torch.qint32]:
        return torch.quantize_per_tensor(base, 0.1, 1, dtype=dtype)
    return base.to(dtype)
def test_torch_to(self):
    # First part: several uses of Tensor.to (same-dtype cast, non-constant
    # dtype argument, pin_memory, move to CUDA) must leave no "TensorExpr" in
    # the last optimized graph. Second part: dtype casts done through a frozen
    # scripted module are compared with eager and run through
    # assertLastGraphAllFused.
    def sample():
        return torch.tensor([3.0], dtype=torch.float)

    def expect_no_tensorexpr():
        FileCheck().check_not("TensorExpr").run(
            torch.jit.last_executed_optimized_graph()
        )

    # test no op
    @torch.jit.script
    def foo(x):
        return x.to(torch.float)

    foo(sample())
    foo(sample())
    expect_no_tensorexpr()

    # test not fusing non-const inputs
    @torch.jit.script
    def foo(x, dtype: int):
        return x.to(dtype)

    foo(sample(), torch.int)
    foo(sample(), torch.int)
    expect_no_tensorexpr()

    # test not fusing to_pinned inputs
    @torch.jit.script
    def foo(x, dtype: int):
        return x.to(pin_memory=True)

    foo(sample(), torch.int)
    foo(sample(), torch.int)
    expect_no_tensorexpr()

    # test across-device not supported
    if torch.cuda.is_available():

        @torch.jit.script
        def foo(x):
            return x.to(device="cuda")

        foo(sample())
        foo(sample())
        expect_no_tensorexpr()

    sizes = [(1, 4), (4, 4)]
    # reuses cast impl, smaller dtype set for faster test
    dtypes = [
        torch.bool,
        torch.int,
        torch.float16,
        torch.float32,
        torch.float64,
    ]

    class MyMod(torch.nn.Module):
        def __init__(self, dtype):
            super().__init__()
            self.dtype = dtype

        def forward(self, x):
            return x.to(self.dtype)

    for in_dtype, out_dtype, device, size in product(
        dtypes, dtypes, self.devices, sizes
    ):
        # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
        half_on_cpu = device == "cpu" and in_dtype in [torch.float16, torch.bfloat16]
        if half_on_cpu or in_dtype == out_dtype:
            continue

        x = self.data_for(in_dtype, device, size=size)
        mod = MyMod(out_dtype)
        expected = mod.forward(x)
        # use freezing to make non-Tensor args to `to` constant
        frozen = torch.jit.freeze(torch.jit.script(mod.eval()))
        warmup_forward(frozen.forward, x)
        self.assertEqual(expected, frozen.forward(x))
        self.assertLastGraphAllFused()
@unittest.skip("Temporarily disabled")
def test_masked_fill(self):
    # Currently skipped unconditionally. For each dtype/device/fill value/size
    # combination, trace torch.masked_fill, compare the traced result with
    # eager, and run assertLastGraphAllFused. Any failure is re-raised as a
    # RuntimeError naming the combination.
    dtypes = [
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
        # torch.float16,
        torch.float32,
        torch.float64,
        torch.bool,
    ]
    sizes = [(2,), (4, 4)]
    for self_dtype, device, scalar_val, size in product(
        dtypes, self.devices, [0.4, 3], sizes
    ):
        input_v = self.data_for(self_dtype, device, size=size)
        mask = self.data_for(torch.bool, device, size=size)

        def fn(input_v, mask):
            return torch.masked_fill(input_v, mask, scalar_val)

        ref = fn(input_v, mask)
        try:
            t = torch.jit.trace(fn, (input_v, mask))
            torch.testing.assert_close(ref, t(input_v, mask))
            self.assertLastGraphAllFused()
        except Exception as e:
            # There is no `op` variable in this test; referencing one here
            # would turn every failure into a NameError. Name the op directly.
            raise RuntimeError(
                " ".join(
                    [
                        "Failed:",
                        str(self_dtype),
                        "masked_fill",
                        device,
                        str(size),
                    ]
                )
            ) from e
def test_isnan(self):
    # Trace `x.isnan()` for two NaN-containing inputs cast to a range of dtypes
    # on each device, compare against eager, and run assertLastGraphAllFused.
    # Failures are re-raised as a RuntimeError naming dtype and device.
    with_nan = torch.rand([4])
    with_nan[0] = float("nan")
    inputs = [with_nan, torch.tensor([float("nan"), 0.5])]
    dtypes = [
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.float16,
        torch.float32,
        torch.float64,
        torch.bool,
    ]

    for raw, device, dtype in product(inputs, self.devices, dtypes):
        # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
        if device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
            continue
        inp = raw.to(device=device, dtype=dtype)
        try:
            traced = torch.jit.trace(lambda x: x.isnan(), (inp,))
            warmup_forward(traced, inp)
            self.assertEqual(traced(inp), inp.isnan())
            self.assertLastGraphAllFused()
        except Exception as err:
            raise RuntimeError(f"Failed: {dtype} isnan {device}") from err
def test_gelu(self):
    # For each dtype/device/size that eager F.gelu accepts, trace the call,
    # compare with eager and run assertAllFused.
    #
    # NOTE(review): `cond` below is a bool tensor passed as gelu's
    # `approximate` argument. F.gelu appears to expect a string there, so the
    # eager call may always raise and every combination may be skipped --
    # confirm whether this test currently exercises anything.
    def apply(fn):
        return lambda x, approximate: fn(x, approximate)

    unary_ops = [
        F.gelu,
    ]
    sizes = [(1,), (2,), (4, 4)]
    for dtype, op, device, size in product(
        self.dtypes, unary_ops, self.devices, sizes
    ):
        # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
        if device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
            continue
        try:
            x = self.data_for(dtype, device, size=size)
            cond = self.data_for(torch.bool, device)
            fn = apply(op)
            expected = fn(x, cond)
        except Exception:
            # If eager mode doesn't support a dtype/op/device combo,
            # neither does the fuser. Catch everything to avoid needing to
            # guess what errors might be thrown by eager.
            continue
        try:
            traced = torch.jit.trace(fn, (x, cond))
            torch.testing.assert_close(expected, traced(x, cond))
            self.assertAllFused(traced.graph_for(x, cond))
        except Exception as err:
            raise RuntimeError(
                f"Failed: {dtype} {op.__name__} {device} {size}"
            ) from err
def test_unary_ops(self):
    # With JIT emit hooks disabled: for every dtype/op/device/size combination
    # that eager mode accepts, trace the unary op, compare with eager via
    # assert_close, and run assertAllFused. Failures are re-raised as a
    # RuntimeError naming the combination.
    with torch._jit_internal._disable_emit_hooks():

        def apply(fn):
            return lambda x: fn(x)

        unary_ops = [
            torch.lgamma,
            torch.sigmoid,
            torch.reciprocal,
            torch.neg,
            torch.relu,
            F.relu6,
            torch.log,
            torch.log10,
            torch.log1p,
            torch.log2,
            torch.exp,
            torch.expm1,
            torch.erf,
            torch.erfc,
            torch.cos,
            torch.sin,
            torch.tan,
            torch.acos,
            torch.asin,
            torch.cosh,
            torch.sinh,
            torch.atan,
            torch.tanh,
            F.hardtanh,
            F.hardsigmoid,
            F.hardswish,
            F.softplus,
            F.silu,
            F.mish,
            F.elu,
            torch.sqrt,
            torch.rsqrt,
            torch.abs,
            # TODO broken on int8 since
            # https://github.com/pytorch/pytorch/pull/85144
            # RuntimeError: Invalid integral op_type: 23
            # torch.ceil,
            # torch.floor,
            # torch.round,
            # torch.trunc,
            torch.frac,
            # TODO: broken on ROCm?
            # F.hardshrink,
            F.leaky_relu,
            lambda x: torch.threshold(x, 0, -10),
            # TODO: broken since type promotion was added
            # lambda x: torch.clamp(x, -10, 10),
        ]
        gpu_only = {torch.erf, torch.erfc}
        sizes = [(1,), (2,), (4, 4)]
        for dtype, op, device, size in product(
            self.dtypes, unary_ops, self.devices, sizes
        ):
            on_cpu = device == "cpu"
            # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
            if on_cpu and dtype in [torch.float16, torch.bfloat16]:
                continue
            # todo - re-enable. fails with .500
            if dtype == torch.bfloat16 and op == torch.round:
                continue
            if on_cpu and op in gpu_only:
                continue
            try:
                x = self.data_for(dtype, device, size=size)
                fn = apply(op)
                expected = fn(x)
            except Exception:
                # If eager mode doesn't support a dtype/op/device combo,
                # neither does the fuser. Catch everything to avoid needing to
                # guess what errors might be thrown by eager.
                continue
            try:
                traced = torch.jit.trace(fn, (x,))
                torch.testing.assert_close(expected, traced(x))
                self.assertAllFused(traced.graph_for(x))
            except Exception as err:
                raise RuntimeError(
                    f"Failed: {dtype} {op.__name__} {device} {size}"
                ) from err
def test_binary_ops(self):
    # For every dtype/op/device combination that eager mode accepts, trace the
    # binary op and compare with eager. assertAllFused is run too, except for
    # fmod/remainder on non-floating-point dtypes.
    def apply(fn):
        return lambda x, y: fn(x, y)

    binary_ops = [
        operator.__and__,
        operator.__or__,
        operator.__xor__,
        torch.add,
        torch.sub,
        torch.mul,
        torch.min,
        torch.max,
        lambda x, y: torch.lerp(x, y, 0.5),
        torch.atan2,
        torch.div,
        torch.eq,
        torch.ne,
        torch.ge,
        torch.gt,
        torch.lt,
        torch.fmod,
        torch.remainder,
        lambda x, y: y.type_as(x),
    ]
    fp_only = [
        torch.fmod,
        torch.remainder,
    ]
    for dtype, op, device in product(self.dtypes, binary_ops, self.devices):
        if device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
            continue
        try:
            lhs = self.data_for(dtype, device)
            rhs = self.data_for(dtype, device)
            fn = apply(op)
            expected = fn(lhs, rhs)
        except Exception:
            # If eager mode doesn't support a dtype/op/device combo,
            # neither does the fuser. Catch everything to avoid needing to
            # guess what errors might be thrown by eager.
            continue
        try:
            traced = torch.jit.trace(fn, (lhs, rhs))
            self.assertEqual(expected, traced(lhs, rhs))
            check_fusion = dtype.is_floating_point or op not in fp_only
            if check_fusion:
                self.assertAllFused(traced.graph_for(lhs, rhs))
        except Exception as err:
            raise RuntimeError(f"Failed: {dtype} {op.__name__} {device}") from err
def test_binary_scalar_ops(self):
    # Build a one-op IR graph for every (scalar type, scalar type, op, device)
    # combination. Graphs the interpreter cannot parse or run are skipped; the
    # rest are compiled with TensorExprKernel and the kernel's result is
    # compared with the interpreter's for every pair of sample values.
    #
    # NOTE(review): the relative indentation inside the template is presumably
    # significant to the IR parser -- keep the body lines indented under
    # `graph(...)`.
    ir_template = """
        graph(%x : {dtype_x}, %y : {dtype_y}):
          %z = {op}(%x, %y)
          return (%z)"""

    binary_ops = [
        "aten::mul",
        "aten::add",
        "aten::sub",
        "aten::div",
        "aten::lt",
        "aten::le",
        "aten::eq",
        "aten::ne",
        "aten::gt",
        "aten::ge",
        "aten::__or__",
        "aten::__xor__",
        "aten::__and__",
        "aten::__lshift__",
        "aten::__rshift__",
    ]
    dtypes = ["int", "float", "bool"]
    values = {"int": [10, 3], "float": [12.34, 2.78], "bool": [True, False]}
    for dtype_x, dtype_y, op, device in product(
        dtypes, dtypes, binary_ops, self.devices
    ):
        code = ir_template.format(dtype_x=dtype_x, dtype_y=dtype_y, op=op)

        # Interpret the graph
        try:
            graph = torch._C.parse_ir(code)
            for x, y in product(values[dtype_x], values[dtype_y]):
                torch._C._jit_interpret_graph(graph, (x, y))
        except Exception:
            # If we can't interpret this IR, don't bother checking NNC.
            continue

        # Compile the graph
        try:
            kernel = torch._C._te.TensorExprKernel(graph)
        except Exception as err:
            raise RuntimeError(
                " ".join(["Compilation failed:", device, str(code)])
            ) from err

        # Run the graph
        for x, y in product(values[dtype_x], values[dtype_y]):
            expected = torch._C._jit_interpret_graph(graph, (x, y))
            try:
                actual = kernel.run((x, y))
                self.assertEqual(expected, actual)
            except Exception as err:
                raise RuntimeError(
                    " ".join(
                        ["Failed at runtime:", device, str(x), str(y), str(code)]
                    )
                ) from err
def test_matmul(self):
    # Trace torch.matmul on CPU for a range of operand shapes and dtypes and
    # compare with eager. assertAllFused is only run for the 2D x 2D shapes;
    # for the others only the result is verified. Failures are re-raised as a
    # RuntimeError naming dtype, device and the shape pair.
    if self.dynamic_shapes:
        self.skipTest("don't run matmul with dynamic shapes")

    def fn(x, y):
        return torch.matmul(x, y)

    devices = ["cpu"]  # No cuda support for ext calls yet
    sizes = [
        [[128, 128], [128, 128]],
        [[10, 10], [10, 10]],
        [[1, 16], [16, 128]],
        [[128], [128]],
        [[128], [128, 128]],
        [[3], [3]],
        [[3, 4], [4]],
        [[10, 3, 4], [4]],
        [[10, 3, 4], [10, 4, 5]],
        [[10, 3, 4], [4, 5]],
    ]

    # Only 2D x 2D matrix multiply is supported. For non-supported sizes we
    # still want to run results verification to test that we didn't
    # accidentally fuse it, but we skip the 'is-fused' check.
    # TODO: add support for other shape combinations and make this set empty:
    skip_is_fused_check_sizes = [
        [[128], [128]],
        [[128], [128, 128]],
        [[3], [3]],
        [[3, 4], [4]],
        [[10, 3, 4], [4]],
        [[10, 3, 4], [10, 4, 5]],
        [[10, 3, 4], [4, 5]],
    ]
    for dtype, size, device in product(self.dtypes, sizes, devices):
        if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
            continue
        try:
            size_x, size_y = size
            x = self.data_for(dtype, device, size=size_x)
            y = self.data_for(dtype, device, size=size_y)
            ref = fn(x, y)
        except Exception:
            # If eager mode doesn't support a dtype/op/device combo,
            # neither does the fuser. Catch everything to avoid needing to
            # guess what errors might be thrown by eager.
            continue
        try:
            t = torch.jit.trace(fn, (x, y))
            # First call is a warm-up run; the second one is compared.
            t(x, y)
            self.assertEqual(ref, t(x, y))
            if size not in skip_is_fused_check_sizes:
                self.assertAllFused(t.graph_for(x, y))
        except Exception as e:
            # Include the shape pair: ten shapes share each dtype/device.
            raise RuntimeError(f"Failed: {dtype} {device} {size}") from e
def test_binary_tensor_scalar_ops(self):
    # With JIT emit hooks disabled: for every dtype/op/device/scalar
    # combination that eager mode accepts, trace `op(x, scalar)`, compare with
    # eager, and run assertAllFused. Failures are re-raised as a RuntimeError
    # naming the combination, including the scalar.
    with torch._jit_internal._disable_emit_hooks():

        def apply_with_scalar(fn, scalar):
            return lambda x: fn(x, scalar)

        # FIXME: Fails in IR Eval: torch.int64 and_ cpu
        binary_ops = [
            operator.__and__,
            operator.__or__,
            operator.__xor__,
            torch.add,
            torch.sub,
            torch.mul,
            torch.eq,
            torch.ne,
            torch.ge,
            torch.lt,
            torch.gt,
        ]
        devices = self.devices
        # Maybe we should split this into separate tests to speed it up by
        # only using scalar values relevant to particular ops
        scalars = [1.5, 3, 0, -2.0, -1]
        for dtype, op, device, scalar in product(
            self.dtypes, binary_ops, devices, scalars
        ):
            if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
                continue
            try:
                x = self.data_for(dtype, device)
                fn = apply_with_scalar(op, scalar)
                ref = fn(x)
            except Exception:
                # If eager mode doesn't support a dtype/op/device combo,
                # neither does the fuser. Catch everything to avoid needing to
                # guess what errors might be thrown by eager.
                continue
            try:
                t = torch.jit.trace(fn, (x,))
                self.assertEqual(ref, t(x))
                self.assertAllFused(t.graph_for(x))
            except Exception as e:
                # Name the scalar too: five scalars share each
                # dtype/op/device triple.
                raise RuntimeError(
                    f"Failed: {dtype} {op.__name__} {device} {scalar}"
                ) from e
def test_binary_div_ops(self):
    # For every dtype/op/device/scalar combination that eager mode accepts,
    # trace div/remainder/fmod with a non-zero scalar and compare with eager.
    # Only results are compared; this test makes no fusion assertion.
    def apply_with_scalar(fn, scalar):
        return lambda x: fn(x, scalar)

    binary_ops = [
        torch.div,
        torch.remainder,
        torch.fmod,
    ]
    # Maybe we should split this into separate tests to speed it up by
    # only using scalar values relevant to particular ops
    scalars = [1.5, 3, -2.0, -1]  # skip 0
    for dtype, op, device, scalar in product(
        self.dtypes, binary_ops, self.devices, scalars
    ):
        if device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
            continue
        try:
            x = self.data_for(dtype, device)
            fn = apply_with_scalar(op, scalar)
            expected = fn(x)
        except Exception:
            # If eager mode doesn't support a dtype/op/device combo,
            # neither does the fuser. Catch everything to avoid needing to
            # guess what errors might be thrown by eager.
            continue
        try:
            traced = torch.jit.trace(fn, x)
            self.assertEqual(expected, traced(x))
        except Exception as err:
            failure = f"Failed: {dtype} {op.__name__} {device} {scalar}"
            raise RuntimeError(failure) from err
def test_binary_pow(self):
    # For float32/float64 inputs and a set of scalar exponents, trace
    # `torch.pow(x, scalar)` on each device, compare with eager, and run
    # assertAllFused. Failures are re-raised as a RuntimeError naming the
    # combination, including the scalar.
    def apply_with_scalar(fn, scalar):
        return lambda x: fn(x, scalar)

    dtypes = [
        # FIXME: 'pow' fails with dtype=torch.float16/device=cuda/scalar=0
        # torch.float16,
        torch.float32,
        torch.float64,
        # torch.bool intentionally not included
    ]
    binary_ops = [
        torch.pow,
    ]
    # Maybe we should split this into separate tests to speed it up by
    # only using scalar values relevant to particular ops
    scalars = [1.5, 3, 0, -2.0, -1]
    for dtype, op, device, scalar in product(
        dtypes, binary_ops, self.devices, scalars
    ):
        if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
            continue
        try:
            x = self.data_for(dtype, device)
            fn = apply_with_scalar(op, scalar)
            ref = fn(x)
        except Exception:
            # If eager mode doesn't support a dtype/op/device combo,
            # neither does the fuser. Catch everything to avoid needing to
            # guess what errors might be thrown by eager.
            continue
        try:
            t = torch.jit.trace(fn, (x,))
            self.assertEqual(ref, t(x))
            self.assertAllFused(t.graph_for(x))
        except Exception as e:
            # Name the scalar too: five exponents share each dtype/device.
            raise RuntimeError(
                f"Failed: {dtype} {op.__name__} {device} {scalar}"
            ) from e
def test_ternary_ops(self):
    # For every dtype/op/device combination that eager mode accepts, trace
    # lerp/addcmul on three sample tensors, compare with eager, and run
    # assertAllFused.
    def apply(fn):
        return lambda x, y, z: fn(x, y, z)

    ternary_ops = [
        torch.lerp,
        torch.addcmul,
    ]
    for dtype, op, device in product(self.dtypes, ternary_ops, self.devices):
        if device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
            continue
        try:
            a = self.data_for(dtype, device)
            b = self.data_for(dtype, device)
            c = self.data_for(dtype, device)
            fn = apply(op)
            expected = fn(a, b, c)
        except Exception:
            # If eager mode doesn't support a dtype/op/device combo,
            # neither does the fuser. Catch everything to avoid needing to
            # guess what errors might be thrown by eager.
            continue
        try:
            traced = torch.jit.trace(fn, (a, b, c))
            self.assertEqual(expected, traced(a, b, c))
            self.assertAllFused(traced.graph_for(a, b, c))
        except Exception as err:
            raise RuntimeError(f"Failed: {dtype} {op.__name__} {device}") from err
def test_ternary_norm_ops(self):
    # For every dtype/device combination that eager mode accepts, trace
    # F.batch_norm called with a [5, 3, 128, 128] input and two [3] tensors,
    # compare with eager, and run assertAllFused.
    def apply(fn):
        return lambda x, y, z: fn(x, y, z)

    ternary_ops = [
        F.batch_norm,
    ]
    for dtype, op, device in product(self.dtypes, ternary_ops, self.devices):
        if device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
            continue
        try:
            a = self.data_for(dtype, device, size=[5, 3, 128, 128])
            b = self.data_for(dtype, device, size=[3])
            c = self.data_for(dtype, device, size=[3])
            fn = apply(op)
            expected = fn(a, b, c)
        except Exception:
            # If eager mode doesn't support a dtype/op/device combo,
            # neither does the fuser. Catch everything to avoid needing to
            # guess what errors might be thrown by eager.
            continue
        try:
            traced = torch.jit.trace(fn, (a, b, c))
            self.assertEqual(expected, traced(a, b, c))
            self.assertAllFused(traced.graph_for(a, b, c))
        except Exception as err:
            raise RuntimeError(f"Failed: {dtype} {op.__name__} {device}") from err
@unittest.skip(
    "FIXME: fuser doesn't include ListConstruct nodes to the group causing a failure"
)
def test_list_ops(self):
    # Currently skipped unconditionally. Traces torch.cat over a list of three
    # squared tensors, compares with eager, and runs assertAllFused.
    def apply(fn):
        return lambda x, y, z: fn([x * x, y * y, z * z])

    list_ops = [
        torch.cat,
    ]
    for dtype, op, device in product(self.dtypes, list_ops, self.devices):
        if device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
            continue
        try:
            a = self.data_for(dtype, device, size=[5, 4, 1, 7])
            b = self.data_for(dtype, device, size=[5, 4, 1, 7])
            c = self.data_for(dtype, device, size=[5, 4, 1, 7])
            fn = apply(op)
            expected = fn(a, b, c)
        except Exception:
            # If eager mode doesn't support a dtype/op/device combo,
            # neither does the fuser. Catch everything to avoid needing to
            # guess what errors might be thrown by eager.
            continue
        try:
            traced = torch.jit.trace(fn, (a, b, c))
            self.assertEqual(expected, traced(a, b, c))
            self.assertAllFused(traced.graph_for(a, b, c))
        except Exception as err:
            raise RuntimeError(f"Failed: {dtype} {op.__name__} {device}") from err
def test_where_ops(self):
    # For every dtype/variant/device combination that eager mode accepts, trace
    # torch.where (tensor/tensor, tensor/float constant, int constant/tensor),
    # compare with eager, and run assertAllFused.
    def apply(fn):
        return lambda cond, x, y: fn(cond, x, y)

    where_variants = [
        torch.where,
        lambda cond, x, y: torch.where(cond, x, 3.1415),
        lambda cond, x, y: torch.where(cond, 42, y),
    ]
    for dtype, op, device in product(self.dtypes, where_variants, self.devices):
        if device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
            continue
        try:
            cond = self.data_for(torch.bool, device)
            lhs = self.data_for(dtype, device)
            rhs = self.data_for(dtype, device)
            fn = apply(op)
            expected = fn(cond, lhs, rhs)
        except Exception:
            # If eager mode doesn't support a dtype/op/device combo,
            # neither does the fuser. Catch everything to avoid needing to
            # guess what errors might be thrown by eager.
            continue
        try:
            traced = torch.jit.trace(fn, (cond, lhs, rhs))
            self.assertEqual(expected, traced(cond, lhs, rhs))
            self.assertAllFused(traced.graph_for(cond, lhs, rhs))
        except Exception as err:
            raise RuntimeError(f"Failed: {dtype} {op.__name__} {device}") from err
def test_unsupported_dtypes(self):
    # For dtypes listed as unsupported, a traced `x * x + x` must still match
    # eager and its graph must contain zero fusion groups. Combinations eager
    # mode rejects are skipped.
    unsupported_dtypes = [
        torch.uint8,
        torch.complex32,
        torch.complex64,
        torch.complex128,
        torch.qint8,
        torch.quint8,
        torch.qint32,
    ]
    for dev in self.devices:

        def fn(x):
            return x * x + x

        for dtype in unsupported_dtypes:
            try:
                x = self.data_for(dtype, dev)
                expected = fn(x)
            except Exception:
                # If eager mode doesn't support a dtype/op/device combo,
                # neither does the fuser. Catch everything to avoid needing to
                # guess what errors might be thrown by eager.
                continue
            traced = torch.jit.trace(fn, (x,))
            self.assertEqual(expected, traced(x))
            groups = self.findFusionGroups(traced.graph_for(x))
            self.assertEqual(len(groups), 0)
def test_superslomo(self):
    # Script a mul/add/div kernel with four outputs, run it four times on
    # strided inputs of mixed shapes, compare every output with eager, and run
    # assertAllFused (tuple construction excepted). CPU is dropped from the
    # device list when LLVM is not enabled.
    devices = self.devices.copy()
    if not LLVM_ENABLED:
        devices.remove("cpu")
    for dev in devices:
        # Test extracted from Super-SloMo: https://github.com/avinashpaliwal/Super-SloMo
        # A few interesting things happen here: strided inputs of mixed size,
        # plus outputs of mixed shapes.  The latter characteristic happened to
        # expose a memory corruption bug due to not properly guarding the
        # outputs.
        def eager(t0, t1, t2, t3, t4):
            t5 = torch.mul(t0, t4)
            t6 = torch.mul(t2, t3)
            t7 = torch.mul(t6, t1)
            t9 = torch.add(t5, t7)
            t11 = torch.add(t0, t6)
            ft_p = torch.div(t9, t11)
            return (ft_p, t11, t9, t6)

        t0 = torch.rand(1, 6, 352, 352, device=dev).transpose(0, 1)
        t1 = torch.rand(6, 3, 352, 352, device=dev)
        t2 = torch.rand(6, device=dev)[None, None, None, :].permute(3, 0, 1, 2)
        t3 = torch.rand(6, 1, 352, 352, device=dev)
        t4 = torch.rand(6, 3, 352, 352, device=dev)
        inputs = [t0, t1, t2, t3, t4]

        scripted = torch.jit.script(eager)
        for _ in range(4):
            for actual, expected in zip(scripted(*inputs), eager(*inputs)):
                torch.testing.assert_close(actual, expected)
                self.assertAllFused(
                    scripted.graph_for(*inputs),
                    except_for={"prim::TupleConstruct"},
                )
def test_sub_gt_and(self):
    # checkScript a function whose boolean intermediate `k` is only used inside
    # a branch that is not taken for the threshold passed here (0.1).
    for dev in self.devices:

        def eager(t1, t2, t3, t4, t: float):
            w = t1 - t2
            h = t3 - t4
            k = (w > t) & (h > t)
            assert k.dtype == torch.bool
            if t > 0.5:
                # Putting a use of k in a never-executed conditional prevents
                # profiling its type, which leaves it as "Tensor".  If we
                # propagate Tensor back to the definition of k, we have to be
                # careful not to create a fusion group containing it.
                return k + 1
            return w

        sample = torch.rand(8, dtype=torch.float, device=dev)
        self.checkScript(eager, (sample, sample, sample, sample, 0.1))
@skipIfTorchDynamo("too slow")
def test_chunk_mul_one(self):
    # checkScript a function that chunks its input in three along the last dim
    # and multiplies only the first chunk. Skipped under dynamic shapes.
    if self.dynamic_shapes:
        self.skipTest("TODO: chunk dynamic shapes")

    for dev in self.devices:

        def eager(x):
            z, y, w = torch.chunk(x, 3, -1)
            return z * 3, y, w

        inp = torch.rand(64, 1, 3072, dtype=torch.float, device=dev)
        eager(inp)
        self.checkScript(eager, (inp,))
def test_eq_unsqueeze_type_as(self):
    # checkScript a chain of `== 1`, unsqueeze and type_as that returns both
    # the converted tensor and the boolean mask.
    for dev in self.devices:

        def eager(a, b):
            mask = b == 1
            mask = torch.unsqueeze(mask, -1)
            x = mask.type_as(a)
            return x, mask

        float_in = torch.rand(1, 64, 1024, device=dev, dtype=torch.float)
        long_in = torch.randint(-2, 2, (1, 64), device=dev, dtype=torch.long)
        self.checkScript(eager, (float_in, long_in))
def test_neg_pow(self):
    # checkScript neg(pow(a, b)) for tensor/tensor, tensor/float and
    # float/tensor operands. Only correctness is checked; the fusion
    # assertions are commented out.
    def eager_tt(a: torch.Tensor, b: torch.Tensor):
        return torch.neg(torch.pow(a, b))

    def eager_ts(a: torch.Tensor, b: float):
        return torch.neg(torch.pow(a, b))

    def eager_st(a: float, b: torch.Tensor):
        return torch.neg(torch.pow(a, b))

    base = torch.rand(1, dtype=torch.float)
    exponent = torch.rand(1, dtype=torch.float)
    scalar = exponent.item()

    self.checkScript(eager_tt, (base, exponent))
    # TODO: re-enable fusion, which doesn't work right now. just test correctness for now
    # self.assertAllFused(script.graph_for(a, b))
    self.checkScript(eager_ts, (base, scalar))
    # self.assertAllFused(script.graph_for(a, s))
    self.checkScript(eager_st, (scalar, exponent))
    # self.assertAllFused(script.graph_for(s, b))
@unittest.skipIf(not LLVM_ENABLED, "Too slow to run with the TE interpreter")
def test_conv2d_depthwise(self):
    # checkScript a 3x3 conv2d with groups equal to the channel count (72) and
    # run assertAllFused on its graph. Skipped under dynamic shapes.
    if self.dynamic_shapes:
        self.skipTest("don't run conv with dynamic shapes")

    def eager(input, weight, bias):
        return torch.conv2d(input, weight, bias, stride=1, padding=1, groups=72)

    conv_args = (
        torch.rand((1, 72, 56, 56), dtype=torch.float),
        torch.rand((72, 1, 3, 3), dtype=torch.float),
        torch.rand((72), dtype=torch.float),
    )
    scripted = self.checkScript(eager, conv_args)
    self.assertAllFused(scripted.graph_for(*conv_args))
def test_conv2d(self):
    # checkScript a 3x3 conv2d with groups=1 and require that "TensorExpr" does
    # not appear in the last optimized graph. Skipped under dynamic shapes.
    if self.dynamic_shapes:
        self.skipTest("don't run conv with dynamic shapes")

    def eager(input, weight, bias):
        return torch.conv2d(input, weight, bias, stride=1, padding=1, groups=1)

    conv_args = (
        torch.rand((1, 64, 56, 56), dtype=torch.float),
        torch.rand((64, 64, 3, 3), dtype=torch.float),
        torch.rand((64), dtype=torch.float),
    )
    self.checkScript(eager, conv_args)
    last_graph = torch.jit.last_executed_optimized_graph()
    FileCheck().check_not("TensorExpr").run(last_graph)
def test_type_as_cat(self):
    # With fusion-group inlining active: trace cat((x, y.type_as(x)), dim=1)
    # for every pair of dtypes (float16/bfloat16 removed), compare with eager
    # for two different y values, and run assertAllFused.
    with inline_fusion_groups():

        def eager(x, y):
            return torch.cat((x, y.type_as(x)), dim=1)

        dtypes = self.dtypes.copy()
        # CPU fuser doesn't support float16.
        dtypes.remove(torch.float16)
        dtypes.remove(torch.bfloat16)
        for x_dtype, y_dtype in product(dtypes, dtypes):
            x = torch.randint(2, (1, 13)).to(x_dtype)
            zero = torch.tensor([[0]]).to(y_dtype)
            one = torch.tensor([[1]]).to(y_dtype)
            traced = torch.jit.trace(eager, (x, zero))
            for _ in range(3):
                torch.testing.assert_close(traced(x, zero), eager(x, zero))
                torch.testing.assert_close(traced(x, one), eager(x, one))
            self.assertAllFused(traced.graph_for(x, one))
def test_to_device(self):
    # checkScript `x.to(device="cpu").relu()` on a CPU tensor and run
    # assertAllFused on its graph.
    def eager(x):
        return x.to(device="cpu").relu()

    inp = torch.rand(8)
    scripted = self.checkScript(eager, (inp,))
    self.assertAllFused(scripted.graph_for(inp))
def test_dims(self):
    # checkScript a broadcasted division where the numerator is an as_strided
    # (1, 1, 768) view with strides (768, 1, 1), then run assertAllFused.
    def eager(x, y):
        return x / (y + 0.0001)

    line = torch.linspace(-1, 1, 768, dtype=torch.float32)
    numer = line.as_strided((1, 1, 768), (768, 1, 1))
    denom = torch.tensor([[[2.0]]], dtype=torch.float32)
    scripted = self.checkScript(eager, (numer, denom))
    self.assertAllFused(scripted.graph_for(numer, denom))
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
def test_channels_last_dims_dynamic(self):
    # For every subset of the four dims set to size 1, trace an add on a
    # channels_last CUDA input under the ("DYNAMIC", 20) strategy. The output
    # must match eager, be channels_last contiguous, and the last optimized
    # graph must mention "TensorExpr".
    def eager(x, y):
        return x + (y + 0.0001)

    indices = [0, 1, 2, 3]
    unit_dim_sets = [
        subset
        for count in range(0, len(indices) + 1)
        for subset in combinations(indices, count)
    ]

    for unit_dims in unit_dim_sets:
        size = [2, 3, 4, 5]
        for dim in unit_dims:
            size[dim] = 1
        inp = torch.rand(size).to(memory_format=torch.channels_last).cuda()
        with texpr_enable_strategy([("DYNAMIC", 20)]):
            traced = torch.jit.trace(eager, (inp, inp))
            for _ in range(3):
                out = traced(inp, inp)
            expected = eager(inp, inp)
            self.assertEqual(expected, out)
            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            last_graph = torch.jit.last_executed_optimized_graph()
            FileCheck().check("TensorExpr").run(last_graph)
def test_exhaust_specializations(self):
    # Under the ("STATIC", 1) strategy, run a scripted function with two
    # different input shapes; after inlining, the last optimized graph must
    # contain "TensorExpr" exactly twice.
    with texpr_enable_strategy([("STATIC", 1)]):

        @torch.jit.script
        def foo(x):
            return x + x + x

        for shape in ([2, 2], [4, 4, 4]):
            for _ in range(3):
                foo(torch.rand(shape))

        graph = torch.jit.last_executed_optimized_graph()
        torch._C._jit_pass_inline(graph)

        FileCheck().check_count("TensorExpr", 2, exactly=True).run(graph)
def test_unsqueeze_var_dim(self):
    # checkScript a multiply with an unsqueeze whose dim is passed as an int
    # argument rather than written as a literal.
    def eager(x, y, z: int):
        return x * torch.unsqueeze(y, dim=z)

    permuted = torch.rand(4, 4, 64).permute(1, 0, 2)
    flat = torch.rand(4, 4)
    dim = 2
    self.checkScript(eager, (permuted, flat, dim))
def _test_fwd_bwd(self, fn):
    """Run ``fn`` eagerly and scripted through 11 forward/backward steps.

    Both versions start from the same ``arange(-10, 10)`` input, receive the
    same random output gradient on each step, and apply the same
    ``x -= 0.1 * x.grad`` update. After the loop, the outputs of the final
    step are compared with ``torch.testing.assert_close``.
    """
    eager_x = torch.arange(-10, 10, dtype=torch.float32, requires_grad=True)
    script_x = torch.arange(-10, 10, dtype=torch.float32, requires_grad=True)
    scripted = torch.jit.script(fn)
    for _ in range(11):
        eager_out = fn(eager_x)
        grad_out = torch.rand_like(eager_out)
        eager_out.backward(grad_out)

        script_out = scripted(script_x)
        script_out.backward(grad_out)

        with torch.no_grad():
            eager_x -= 0.1 * eager_x.grad
            script_x -= 0.1 * script_x.grad
            eager_x.grad = None
            script_x.grad = None
    torch.testing.assert_close(eager_out, script_out)
def test_relu_fwd_bwd(self):
    # Forward/backward comparison (via _test_fwd_bwd) for relu(x * 1.01).
    def eager(x):
        return torch.relu(x * 1.01)

    self._test_fwd_bwd(fn=eager)
def test_hardswish_fwd_bwd(self):
    # Forward/backward comparison (via _test_fwd_bwd) for hardswish(x) * 1.01.
    def eager(x):
        return F.hardswish(x) * 1.01

    self._test_fwd_bwd(fn=eager)
def test_hardsigmoid_fwd_bwd(self):
    # Forward/backward comparison (via _test_fwd_bwd) for hardsigmoid(x) * 1.01.
    def eager(x):
        return F.hardsigmoid(x) * 1.01

    self._test_fwd_bwd(fn=eager)
def test_cat_graph_opt(self):
    # checkScript log(cat([x, y, z])) on inputs with 5, 2 and 1 rows and run
    # assertLastGraphAllFused.
    def foo(x, y, z):
        return torch.log(torch.cat([x, y, z]))

    cat_inputs = (torch.rand([5, 5]), torch.rand([2, 5]), torch.rand([1, 5]))
    self.checkScript(foo, cat_inputs)
    # TODO: unclear why the updated graph isn't reflected in last_optimized_graph
    self.assertLastGraphAllFused()
def test_dynamic_cat(self):
    # With fusion-group inlining active, run a scripted nested-cat list
    # comprehension three times on lists whose ys/zs element sizes differ per
    # pair. No assertion is made beyond the calls completing.
    with inline_fusion_groups():

        @torch.jit.script
        def repro(
            xs: list[torch.Tensor], ys: list[torch.Tensor], zs: list[torch.Tensor]
        ):
            return [
                torch.cat([x, torch.cat([y, z], dim=-1)], dim=-1)
                for x, y, z in zip(xs, ys, zs)
            ]

        N = 3
        for _ in range(3):
            xs = [torch.ones(21) for _ in range(N)]
            # Note: concat of ys and zs will have the same size for each
            # pair, even though the individual ys and zs do not.
            ys = [torch.ones(N - k) for k in range(N)]
            zs = [torch.ones(k) for k in range(N)]
            repro(xs, ys, zs)
def test_scalar_only_inputs(self):
    # checkScript a function whose only argument is a float; the tensor it
    # multiplies is created inside the function.
    def eager(b: float):
        a = torch.ones(1)
        return a * b

    self.checkScript(eager, (1.0,))
def test_cat_2k_args(self):
    # With fusion-group inlining active, checkTrace relu(cat(...)) over a list
    # of 2000 references to the same tensor; the traced graph must contain no
    # fusion groups.
    with inline_fusion_groups():

        def eager(x):
            return torch.relu(torch.cat([x for _ in range(2000)]))

        inp = torch.randn(1)
        traced = self.checkTrace(eager, (inp,))
        groups = self.findFusionGroups(traced.graph_for(inp))
        self.assertEqual(len(groups), 0)
def test_adaptive_avg_pool2d(self):
    # TODO: once the adaptive_avg_pool2d is available in OpInfo DB, this
    # test should be moved there
    #
    # Compiles the traced graph directly with TensorExprKernel and compares the
    # kernel's result with the traced function's, for a tuple and an int
    # output size.
    with inline_fusion_groups():

        def foo1(x):
            return torch.nn.functional.adaptive_avg_pool2d(x, (2, 2))

        def foo2(x):
            return torch.nn.functional.adaptive_avg_pool2d(x, 2)

        inp = torch.randn(4, 4, 4)
        for pool_fn in (foo1, foo2):
            traced = torch.jit.trace(pool_fn, (inp,))
            kernel = torch._C._te.TensorExprKernel(traced.graph)
            expected = traced(inp)
            self.assertEqual(kernel.run((inp,)), expected)
def test_unrolled_cat(self):
    # With fusion-group inlining active, script a loop that grows a tensor with
    # cat, warm it up on a one-row input, then check that an eight-row input
    # still matches eager.
    with inline_fusion_groups():

        def eager(x):
            ret = torch.empty(0)
            for i in range(x.shape[0]):
                ret = torch.cat([ret, x[i].relu()])
            return ret

        scripted = torch.jit.script(eager)

        # Warm up with size=1 tensor; since the loop iterates once the
        # profile data will be "burned in" assuming size=1, and then
        # unrolled.
        one_row = torch.ones(1, 1)
        for _ in range(3):
            scripted(one_row)

        torch.testing.assert_close(eager(one_row), scripted(one_row))

        # Now when an input hits the unrolled path, it will produce an
        # incorrectly-sized tensor, since size=1 has been burned in.
        eight_rows = torch.ones((8, 1))
        torch.testing.assert_close(eager(eight_rows), scripted(eight_rows))
@skipIfTorchDynamo("too slow")
@unittest.skipIf(TEST_WITH_ASAN, "takes 10+ minutes on asan")
@unittest.skipIf(TEST_WITH_ROCM, "Tensor-likes are not close for nans")
def test_batch_norm(self):
    # Trace batch_norm(...).relu() with and without weight/bias on each
    # device, run assertAllFused, and compare with eager (NaNs treated as
    # equal).
    def check(fn, args):
        traced = torch.jit.trace(fn, args)
        self.assertAllFused(traced.graph_for(*args))
        # TODO: Are `NaN`'s actually ok here or did this pass silently before, because `equal_nan=True` was the
        #  default?
        torch.testing.assert_close(fn(*args), traced(*args), equal_nan=True)

    def bn(i, x):
        return torch.batch_norm(i, x, x, x, x, False, 0.1, 1e-4, False).relu()

    def bn_no_weight(i, x):
        return torch.batch_norm(i, None, x, x, x, False, 0.1, 1e-4, False).relu()

    def bn_no_bias(i, x):
        return torch.batch_norm(i, x, None, x, x, False, 0.1, 1e-4, False).relu()

    def bn_neither(i, x):
        return torch.batch_norm(i, None, None, x, x, False, 0.1, 1e-4, False).relu()

    variants = [bn, bn_no_weight, bn_no_bias, bn_neither]
    for dev in self.devices:
        image = torch.randn(4, 16, 32, 40, device=dev)
        per_channel = torch.randn(16, device=dev)
        for variant in variants:
            check(variant, (image, per_channel))
def test_profiler(self):
    @torch.jit.script
    def test(x, y, z):
        return x * y + z

    inputs = [torch.randn(4) for _ in range(3)]
    # Run the scripted function three times under the autograd profiler and
    # look for the fused kernel's name in the rendered profile table.
    with torch.autograd.profiler.profile() as profile_run:
        for _ in range(3):
            test(*inputs)
    report = profile_run.table()
    self.assertIn("fused_mul_add", report)
def test_skip_grad_in_check(self):
    @torch.jit.script
    def foo(x):
        return (x + 2) / 2

    inp = torch.rand([4, 4])
    # First three runs: input does not require grad.
    for _ in range(3):
        foo(inp)

    # Next three runs: the same tensor now requires grad, but the calls happen
    # under inference_mode.
    inp.requires_grad_(True)
    with torch.inference_mode():
        for _ in range(3):
            foo(inp)

    optimized = torch.jit.last_executed_optimized_graph()
    # The inline pass is applied twice, as in the original test.
    for _ in range(2):
        torch._C._jit_pass_inline(optimized)
    # Exactly one prim::If may remain in the inlined graph.
    FileCheck().check_count("prim::If", 1, exactly=True).run(optimized)
def test_dynamic_shapes(self):
    # Trace small pointwise functions, run them with differently sized and
    # laid-out inputs, and count the TensorExprDynamicGuard nodes that remain
    # in the last executed optimized graph.
    from functools import partial

    n = 10

    # Input factories, one per layout (contiguous, transposed, strided slices,
    # channels_last). `R` is not bound yet at this point: each lambda resolves
    # it when called, i.e. after `R` is assigned inside the device loop below.
    gen_tensor = (
        lambda n: R(1, n),
        lambda n: R(n, n),
        lambda n: R(n, n).transpose(0, 1),
        lambda n: R(n + 1, n + 1, 2)[:n, n, 0],
        lambda n: R(n, n, 2)[:, :, 0],
        lambda n: R(n, n + 1, n + 2, n + 3).to(memory_format=torch.channels_last),
    )

    # NOTE(review): presumably selects the "DYNAMIC" fusion strategy with a
    # limit of 20 — confirm against texpr_enable_strategy.
    with texpr_enable_strategy([("DYNAMIC", 20)]):

        # All three functions accept (x, y, z) but use 1, 2 and 3 of them.
        def foo(x, y, z):
            return torch.sigmoid(torch.tanh(x))

        # NOTE(review): presumably stops torch.jit.trace from reusing a cached
        # trace of the same Python function — confirm.
        foo.__disable_jit_function_caching__ = True

        def fi(x, y, z):
            return torch.tanh(x + y)

        fi.__disable_jit_function_caching__ = True

        def fum(x, y, z):
            return torch.tanh(x + y) + z

        fum.__disable_jit_function_caching__ = True

        funcs = [foo, fi, fum]
        with inline_fusion_groups():
            for device in self.devices:
                # `I` is never used below; only `R` feeds the gen_tensor lambdas.
                I = partial(torch.randint, 0, 100, device=device)
                R = partial(torch.randn, device=device)

                # Part 1: for every (function, layout) pair, a fresh trace must
                # end up with exactly one dynamic guard.
                for i, func in enumerate(funcs):
                    num_args = i + 1  # unused
                    for j, gen in enumerate(gen_tensor):
                        inps = (gen(n), gen(n), gen(n))
                        func_s = torch.jit.trace(func, inps, check_trace=False)
                        torch._C._jit_pass_erase_shape_information(func_s.graph)
                        # Two runs at size n ...
                        for _ in range(2):
                            x, y, z = gen(n), gen(n), gen(n)
                            func_s(x, y, z)

                        # ... followed by three runs at size n + 1 (`incr` is
                        # not used in the loop body).
                        for incr in range(3):
                            func_s(*[gen(n + 1) for _ in range(3)])

                        g = torch.jit.last_executed_optimized_graph()
                        torch._C._jit_pass_inline(g)
                        torch._C._jit_pass_dce(g)

                        # We should see only one optimized kernel
                        FileCheck().check_count(
                            "TensorExprDynamicGuard", 1, exactly=True
                        ).run(g)
                        self.assertEqual(func(*inps), func_s(*inps))

                # Part 2: a single trace of `foo` is fed every layout in turn;
                # the final graph must contain one guard per layout.
                gen = gen_tensor[0]
                inps = (gen(n), gen(n), gen(n))
                foo_s = torch.jit.trace(foo, inps)
                torch._C._jit_pass_erase_shape_information(foo_s.graph)
                g_prev = None  # unused
                for gen in gen_tensor:
                    for i in range(3):
                        foo_s(*[gen(n + i) for _ in range(3)])
                        inps = (gen(n), gen(n), gen(n))
                        self.assertEqual(foo_s(*inps), foo(*inps))
                g = torch.jit.last_executed_optimized_graph()
                torch._C._jit_pass_inline(g)
                torch._C._jit_pass_dce(g)
                FileCheck().check_count(
                    "TensorExprDynamicGuard", len(gen_tensor), exactly=True
                ).run(g)
@unittest.skipIf(not RUN_CUDA, "half-precision NNC fusion requires CUDA")
def test_autocast_up(self):
    def f(x):
        y = x._autocast_to_full_precision(True, True)
        z = torch.exp(y)
        return z

    half_input = torch.rand((2, 2), dtype=torch.half, device="cuda")
    scripted = torch.jit.script(f)
    # Two invocations, then the last executed graph must be fully fused.
    for _ in range(2):
        scripted(half_input)
    self.assertLastGraphAllFused()
@unittest.skipIf(not RUN_CUDA, "half-precision NNC fusion requires CUDA")
def test_autocast_down(self):
    def f(x):
        y = torch.sigmoid(x)
        z = y._autocast_to_reduced_precision(True, True, torch.half, torch.half)
        return z

    float_input = torch.rand((2, 2), dtype=torch.float, device="cuda")
    scripted = torch.jit.script(f)
    # Two invocations, then the last executed graph must be fully fused.
    for _ in range(2):
        scripted(float_input)
    self.assertLastGraphAllFused()
@unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel")
def test_to_dtype(self):
    def f(x):
        y = torch.sigmoid(x)
        z = y._autocast_to_reduced_precision(True, True, torch.half, torch.bfloat16)
        h = z._autocast_to_full_precision(True, True)
        i = h.to(dtype=torch.bfloat16)
        j = i.to(dtype=torch.float32)
        return j

    tolerance = 4e-3

    # float32 input: the whole graph is expected to fuse.
    fp32_input = torch.rand((2, 2), dtype=torch.float32)
    fp32_traced = torch.jit.trace(f, fp32_input)
    for _ in range(2):
        fp32_traced(fp32_input)
    self.assertLastGraphAllFused()
    self.assertEqual(
        f(fp32_input), fp32_traced(fp32_input), atol=tolerance, rtol=tolerance
    )

    # bfloat16 input: exactly two fusion groups are expected.
    bf16_input = torch.rand((2, 2), dtype=torch.bfloat16)
    bf16_traced = torch.jit.trace(f, bf16_input)
    for _ in range(2):
        bf16_traced(bf16_input)
    bf16_groups = self.findFusionGroups(bf16_traced.graph_for(bf16_input))
    self.assertEqual(len(bf16_groups), 2)
    self.assertEqual(
        f(bf16_input), bf16_traced(bf16_input), atol=tolerance, rtol=tolerance
    )
def test_with_strict_fusion(self):
    # Covers torch.jit.strict_fusion(): one case that fuses completely, and
    # several cases that must raise with a specific diagnostic. The inner
    # functions are left free of comments because their source text is
    # matched against the error messages below.

    # Case 1: everything inside the context fuses; no aten::add may be left
    # ahead of the TensorExprGroup in the optimized graph.
    def success(x):
        with torch.jit.strict_fusion():
            return x + x + x

    scripted = self.checkScript(success, (torch.rand([4]),))
    g = torch.jit.last_executed_optimized_graph()
    FileCheck().check_not("aten::add").check("prim::TensorExprGroup").run(g)

    # Case 2: torch.rand inside the context is reported as an unfused op.
    def foo(x):
        with torch.jit.strict_fusion():
            return x + x + torch.rand([4]) + 3

    with self.assertRaises(Exception) as error_out:
        foo_s = torch.jit.script(foo)
        foo_s(torch.rand([4]))
        foo_s(torch.rand([4]))
        # Only reached when neither call raised, i.e. right before
        # assertRaises fails — presumably a debugging aid.
        print(torch.jit.last_executed_optimized_graph())
    # The message must name the problem, the op schema and the source line.
    fc = FileCheck().check("Found unfused operators")
    fc.check("aten::rand(SymInt[] size")
    fc.check("torch.rand([4]").run(str(error_out.exception))

    # Case 3: calling the un-scripted function only emits a warning.
    with warnings.catch_warnings(record=True) as warns:
        foo(torch.rand([4]))

    FileCheck().check("Only works in script mode").run(str(warns[0]))

    # Case 4: same unfused-op error when the input requires grad.
    def test_autodiff(x):
        with torch.jit.strict_fusion():
            return torch.rand([4]) + x + x + x

    foo_s = torch.jit.script(test_autodiff)
    inp = torch.rand([4], requires_grad=True)
    with self.assertRaises(Exception) as error_out:
        for _ in range(3):
            foo_s(inp)
    f = FileCheck().check("unfused operators").check("aten::rand")
    f.run(str(error_out.exception))

    # Case 5: two independent computations in one context are reported as
    # multiple fusions. The function is re-scripted on every iteration.
    def test_separate_fusions(x, y):
        with torch.jit.strict_fusion():
            return x + x + x, y + y + y

    inp = torch.rand([4], requires_grad=True)
    with self.assertRaises(Exception) as error_out:
        for _ in range(3):
            foo_s = torch.jit.script(test_separate_fusions)
            foo_s(inp, inp)

    f = FileCheck().check("Found multiple fusions")
    f.run(str(error_out.exception))
def test_constant_chunk_shapes(self):
    # We had an issue where buildShapeExpressions would fail as shown below:
    #
    # %1 : Tensor = Constant[..]  # not supported, we don't build this shape
    # %2 : Tensor = Constant[..]  # not supported
    # %3 : Tensor = aten::add(%1, %2)  # inputs not supported, we don't build shape
    # ... = prim::ConstantChunk[..](%3)  # it forgets to check whether input shapes exist, and fails
    if self.dynamic_shapes:
        self.skipTest("TODO: chunk dynamic shapes")

    for device in self.devices:

        def f(x, y):
            r = torch.tensor(4)
            z1, z2 = (x + y + r).chunk(2, dim=1)
            return z1 * z2

        lhs = torch.randn(4, 4, dtype=torch.float, device=device)
        rhs = torch.randn(4, 4, dtype=torch.float, device=device)

        checked = self.checkTrace(f, (lhs, rhs))
        graph_text = str(checked.graph_for(lhs, rhs))

        # make sure that we are actually testing the right scenario
        scenario_check = FileCheck().check(f"with {FUSION_GROUP}_")
        scenario_check.check_count("ConstantChunk", 1, exactly=True).run(graph_text)

        f_traced = torch.jit.trace(f, (lhs, rhs))

        # make sure repeated runs don't error out; the last result is compared
        for _ in range(4):
            res = f_traced(lhs, rhs)

        self.assertEqual(res, f(lhs, rhs))
@unittest.skipIf(not RUN_CUDA_HALF, "half-precision NNC fusion requires CUDA")
def test_pow_multiple_dtype(self):
    # https://github.com/pytorch/pytorch/issues/75476
    def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
        p = torch.sigmoid(p)
        result = p**gamma
        return result

    half_input = torch.rand((2, 2), dtype=torch.half, device="cuda")
    expected = fn(half_input)

    script_fn = torch.jit.script(fn)
    # Four runs of the scripted function; the result of the last one is the
    # one compared against eager.
    for _ in range(4):
        actual = script_fn(half_input)

    self.assertEqual(expected, actual)
# Concrete variants of TestTEFuser. Each one only pins the `dynamic_shapes`
# flag that the shared test bodies read (e.g. test_constant_chunk_shapes skips
# itself when the flag is True).
class TestTEFuserStatic(TestTEFuser):
    dynamic_shapes = False


class TestTEFuserDynamic(TestTEFuser):
    dynamic_shapes = True


# Drop the base class name from the module namespace; presumably so that only
# the two concrete variants above are picked up by test discovery.
del TestTEFuser
# OpInfo names (in get_name() form, i.e. "<name>" or "<name>.<variant>") that
# TestNNCOpInfo.test_working expects to compile and run with TensorExprKernel.
works_list = [
    "__radd__",
    "__rdiv__",
    "__rmul__",
    "__rmod__",
    "abs",
    "acos",
    "add",
    "addcmul",
    "addmm.decomposed",
    "asin",
    "atan",
    "atan2",
    "ceil",
    "clamp",
    "clamp.scalar",
    "contiguous",
    "cos",
    "cosh",
    "div.no_rounding_mode",
    "div.true_rounding",
    "div.floor_rounding",
    "div.trunc_rounding",
    "eq",
    "erf",
    "erfc",
    "exp",
    "expand",
    "expand_as",
    "expm1",
    "floor",
    "fmod",
    "fmod.autodiffed",
    "ge",
    "gt",
    "isnan",
    "le",
    "lerp",
    "lgamma",
    "log",
    "log10",
    "log1p",
    "log2",
    "lt",
    "masked_fill",
    "max.binary",
    "mean",
    "min.binary",
    "mm",
    "mul",
    "ne",
    "neg",
    "nn.functional.hardshrink",
    "nn.functional.hardsigmoid",
    "nn.functional.hardswish",
    "nn.functional.softplus",
    "nn.functional.hardtanh",
    "nn.functional.leaky_relu",
    "nn.functional.relu",
    "nn.functional.relu6",
    "nn.functional.softsign",
    "nn.functional.tanhshrink",
    "nn.functional.threshold",
    "permute",
    "pow",
    "reciprocal",
    "remainder",
    "remainder.autodiffed",
    "reshape",
    "reshape_as",
    "round",
    "rsub",
    "rsub.rsub_tensor",
    "rsqrt",
    "sigmoid",
    "sign",
    "sin",
    "sinh",
    "sqrt",
    "sub",
    "sum",
    "t",
    "tan",
    "tanh",
    "transpose",
    "true_divide",
    "trunc",
    "unsqueeze",
    "view",
    "view_as",
    "where",
    # dtype-conversion ops: first the plain names ("bool" ... "short"), then
    # the same names with the ".channels_last" variant suffix.
    *(
        cast_name + variant_suffix
        for variant_suffix in ("", ".channels_last")
        for cast_name in (
            "bool",
            "byte",
            "char",
            "double",
            "float",
            "half",
            "int",
            "long",
            "short",
        )
    ),
]
# Ops for which TestNNCOpInfo.test_failures expects te_compile to raise.
known_failures = ["__rmatmul__", "frac", "matmul"]

# If your OpInfo test causes this test to fail, add it here
skip_ops = ["conj"]
def get_name(op):
    # Name used to look an OpInfo up in the op lists of this file:
    # "<name>" when the variant is empty, otherwise "<name>.<variant>".
    name, variant = op.name, op.variant_test_name
    return ".".join((name,) if variant == "" else (name, variant))
# The only purpose of this empty class is to make super() calls possible in
# TestNNCOpInfo:
# - super() [with no arguments] fails, presumably because of how
#   instantiate_device_type_tests works.
# - super(TestNNCOpInfo, self) fails because TestNNCOpInfo gets deleted from
#   global scope.
# - super(JitCommonTestCase, self).fn() would skip the JitCommonTestCase.fn()
#   implementation.
# Using super(TestNNCOpInfoParent, self) avoids all three problems.
class TestNNCOpInfoParent(JitCommonTestCase):
    pass
class TestNNCOpInfo(TestNNCOpInfoParent):
    # OpInfo-driven tests: every op in op_db is sorted by get_name() into
    # works_list / known_failures / "everything else", and each group has its
    # own expectation about compiling the op with TensorExprKernel.

    def setUp(self):
        # The lookup deliberately starts after TestNNCOpInfoParent (see the
        # comment on that class), which resolves to JitCommonTestCase.setUp.
        super(TestNNCOpInfoParent, self).setUp()
        self.tensorexpr_options = TensorExprTestOptions()

    def tearDown(self):
        self.tensorexpr_options.restore()
        super(TestNNCOpInfoParent, self).tearDown()

    def te_compile(self, device, dtype, op):
        # For every sample input of `op`: generate a tiny wrapper function
        # whose parameters are the tensor arguments (non-tensor arguments are
        # inlined via repr), trace it, build a TensorExprKernel from the traced
        # graph and compare both `run` and `fallback` against eager.
        # Raises whatever tracing / kernel construction / comparison raises.
        if op.name in skip_ops:
            return
        for sample_input in op.sample_inputs(device, dtype, requires_grad=False):
            positional = [sample_input.input, *sample_input.args]
            param_names = []
            param_values = []
            call_args = []
            for idx, value in enumerate(positional):
                if isinstance(value, torch.Tensor):
                    param_names.append(f"arg_{idx}")
                    param_values.append(value)
                    call_args.append(param_names[-1])
                else:
                    call_args.append(repr(value))

            for key, value in sample_input.kwargs.items():
                if isinstance(value, torch.Tensor):
                    # Tensor kwargs become parameters named after the kwarg.
                    param_names.append(key)
                    param_values.append(value)
                    call_args.append(f"{key} = {key}")
                else:
                    call_args.append(f"{key} = {repr(value)}")

            # The source is assembled with explicit newlines so the generated
            # function does not depend on how this file is indented.
            code = (
                f"\ndef f({', '.join(param_names)}):\n"
                f"    return op.op({', '.join(call_args)})"
            )
            # `inf` is provided because repr() of infinite floats is "inf".
            # The exec'd text comes only from OpInfo sample inputs.
            namespace = {"torch": torch, "inf": math.inf, "op": op}
            exec(code, namespace)
            f = namespace["f"]
            f.__module__ = "test"
            # Eager run before tracing; the result is intentionally unused.
            f(*param_values)

            traced = torch.jit.trace(f, param_values)
            kernel = torch._C._te.TensorExprKernel(traced.graph)
            correct_val = f(*param_values)
            self.assertEqual(kernel.run(tuple(param_values)), correct_val)
            self.assertEqual(kernel.fallback(tuple(param_values)), correct_val)

    @onlyCPU
    @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel")
    @ops(
        [op for op in op_db if get_name(op) in works_list],
        allowed_dtypes=(torch.float,),
    )
    def test_working(self, device, dtype, op):
        self.te_compile(device, dtype, op)

    @onlyCPU
    @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel")
    @ops(
        [op for op in op_db if get_name(op) in known_failures],
        allowed_dtypes=(torch.float,),
    )
    def test_failures(self, device, dtype, op):
        # Ops in known_failures must raise somewhere inside te_compile.
        try:
            self.te_compile(device, dtype, op)
        except Exception:
            pass
        else:
            raise RuntimeError(
                "Expected test to fail. If it now works, move op into works_list"
            )

    @onlyCPU
    @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel")
    @ops(
        [op for op in op_db if get_name(op) not in works_list + known_failures],
        allowed_dtypes=(torch.float,),
    )
    def test_unsupported(self, device, dtype, op):
        # Ops that are in neither list are expected to fail in te_compile.
        if get_name(op) in skip_ops:
            return
        with warnings.catch_warnings():
            # The warning class has to be spelled torch.jit.TracerWarning: a
            # bare `TracerWarning` is not defined in this module, and the
            # resulting NameError would be swallowed by the handler below,
            # making the test pass without ever calling te_compile. The filter
            # is installed outside the `try` for the same reason.
            # NOTE(review): ops that compile fine but are missing from
            # works_list are now actually reported by this test.
            warnings.simplefilter("ignore", torch.jit.TracerWarning)
            try:
                self.te_compile(device, dtype, op)
            except Exception:
                return
        raise RuntimeError(
            "Expected test to fail. If it now works, move op into works_list"
        )

    @slowTest
    @onlyCPU
    @ops(op_db, dtypes=OpDTypes.supported)
    def test_nnc_correctness(self, device, dtype, op):
        # Compare every traced variant of `op` against the untraced variant
        # on all sample inputs.
        if not op.supports_tracing:
            self.skipTest("Requires tracing support")

        with NoTracerWarnContextManager():
            variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op)

            # bfloat16 gets a much looser tolerance than the other dtypes.
            atol = 2e-1 if dtype == torch.bfloat16 else 1e-5
            rtol = 2e-1 if dtype == torch.bfloat16 else 1e-5

            for variant, sample in variant_sample_pairs:
                trace = create_traced_fn(self, variant, cache_traced_fn=True)
                ref = variant(
                    *clone_inputs((sample.input, *sample.args)), **sample.kwargs
                )

                # First call of the traced function is a warm-up; the second
                # call's result is the one that gets compared.
                trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
                val = trace(
                    *clone_inputs((sample.input, *sample.args)), **sample.kwargs
                )

                self.assertEqual(ref, val, atol=atol, rtol=rtol)

            # https://github.com/pytorch/pytorch/issues/35600
            # each torch.jit.trace adds state to the _python_cu compilation unit
            # since this test traces a lot of functions, out-of-memory can occur
            # if the CU is not cleared.
            torch.jit._state._python_cu.drop_all_functions()
# CPU fuser not currently used in fbcode.
# Both branches are real tuples: `("cuda")` without the trailing comma is just
# the string "cuda", which only works as a filter through substring matching.
only_for = ("cuda",) if IS_FBCODE else ("cpu", "cuda")
instantiate_device_type_tests(TestNNCOpInfo, globals(), only_for=only_for)
# The only purpose of this empty class is to make super() calls possible in
# TestLoopnestRandomization, which calls super(TestLoopnestRandomizationParent,
# self) in setUp/tearDown. (See TestNNCOpInfoParent)
class TestLoopnestRandomizationParent(JitTestCase):
    pass
class TestLoopnestRandomization(TestLoopnestRandomizationParent):
    # Runs a small fused kernel with PYTORCH_TENSOREXPR_RANDOM_TRANSFORM_SEED
    # set to "1", so the random loopnest transformation code path is exercised.

    def setUp(self):
        # The lookup deliberately starts after TestLoopnestRandomizationParent
        # (see the comment on that class).
        super(TestLoopnestRandomizationParent, self).setUp()

        # Snapshot the fuser switches before overriding them.
        self.old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu()
        self.old_gpu_fuser_state = torch._C._jit_can_fuse_on_gpu()
        torch._C._jit_override_can_fuse_on_cpu(True)
        # TODO: force LLVM. need to add it to asan, mac, windows builds + sandcastle
        # torch._C._jit_set_te_must_use_llvm_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)

        # These setters are stored as if they return the previous value.
        self.old_profiling_executor = torch._C._jit_set_profiling_executor(True)
        self.old_profiling_mode = torch._C._get_graph_executor_optimize(True)

        self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
        torch._C._debug_set_fusion_group_inlining(False)

        self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
        torch._C._jit_set_texpr_fuser_enabled(True)

        # Saved once; nothing changes this flag before the setter below.
        self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
        torch._C._jit_set_te_must_use_llvm_cpu(False)

        # Set the seed to 1. This tests the codepath through random
        # transformation. Remember any value that was already set so that
        # tearDown does not clobber it.
        self.old_random_transform_seed = os.environ.get(
            "PYTORCH_TENSOREXPR_RANDOM_TRANSFORM_SEED"
        )
        os.environ["PYTORCH_TENSOREXPR_RANDOM_TRANSFORM_SEED"] = "1"

    def tearDown(self):
        torch._C._jit_set_profiling_executor(self.old_profiling_executor)
        torch._C._get_graph_executor_optimize(self.old_profiling_mode)

        torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state)
        torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state)
        torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)

        torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
        torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)

        # Restore a pre-existing seed; if there was none, set it back to 0.
        if self.old_random_transform_seed is None:
            os.environ["PYTORCH_TENSOREXPR_RANDOM_TRANSFORM_SEED"] = "0"
        else:
            os.environ["PYTORCH_TENSOREXPR_RANDOM_TRANSFORM_SEED"] = (
                self.old_random_transform_seed
            )
        super(TestLoopnestRandomizationParent, self).tearDown()

    @onlyCPU
    @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel")
    def test_relu(self, device):
        def fn_test_relu(x, y):
            return F.relu(x + 0.5 * y)

        x = torch.randn(4, 4, dtype=torch.float, device=device)
        y = torch.randn(4, 4, dtype=torch.float, device=device)

        fn = fn_test_relu
        traced_fn = torch.jit.trace(fn, (x, y))

        ref = fn(x, y)
        res = traced_fn(x, y)
        # A unittest assertion (not a bare `assert`) so the check also runs
        # under `python -O`; allclose keeps the original tolerances.
        self.assertTrue(torch.allclose(ref, res))
# `only_for` is passed as a real one-element tuple; `("cpu")` without the
# trailing comma is just the string "cpu".
instantiate_device_type_tests(TestLoopnestRandomization, globals(), only_for=("cpu",))

if __name__ == "__main__":
    run_tests()