mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136964 Approved by: https://github.com/justinchuby, https://github.com/albanD
287 lines
10 KiB
Python
287 lines
10 KiB
Python
# Owner(s): ["oncall: jit"]
|
|
|
|
import os
|
|
import sys
|
|
|
|
import torch
|
|
from torch.testing._internal.common_utils import skipIfTorchDynamo
|
|
|
|
|
|
# Make the helper files in test/ importable: walk up from this file's real
# location to the parent of its containing directory (test/) and put that on
# sys.path.
pytorch_test_dir = os.path.realpath(__file__)
pytorch_test_dir = os.path.dirname(pytorch_test_dir)  # directory holding this file
pytorch_test_dir = os.path.dirname(pytorch_test_dir)  # its parent, i.e. test/
sys.path.append(pytorch_test_dir)
|
|
from torch.testing._internal.jit_utils import FileCheck, JitTestCase, warmup_backward
|
|
|
|
|
|
if __name__ == "__main__":
    # This module only defines test cases; they are meant to be run through
    # the top-level test/test_jit.py driver, so refuse direct execution.
    _usage_hint = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage_hint)
|
|
|
|
|
|
@skipIfTorchDynamo()
|
|
class TestProfiler(JitTestCase):
|
|
def setUp(self):
|
|
self.prev_exec = torch._C._jit_set_profiling_executor(True)
|
|
self.prev_profiling = torch._C._get_graph_executor_optimize(True)
|
|
self.inline_autodiff = torch._C._debug_set_autodiff_subgraph_inlining(False)
|
|
self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
|
|
self.can_fuse_on_cpu = torch._C._jit_can_fuse_on_cpu()
|
|
torch._C._jit_set_texpr_fuser_enabled(True)
|
|
torch._C._jit_override_can_fuse_on_cpu(True)
|
|
self.default_dtype = torch.get_default_dtype()
|
|
self.old_reduction_enabled = torch._C._jit_set_texpr_reductions_enabled(True)
|
|
torch.set_default_dtype(torch.double)
|
|
self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
|
|
torch._C._debug_set_fusion_group_inlining(False)
|
|
self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
|
|
torch._C._jit_set_te_must_use_llvm_cpu(False)
|
|
|
|
def tearDown(self):
|
|
torch._C._jit_set_profiling_executor(self.prev_exec)
|
|
torch._C._get_graph_executor_optimize(self.prev_profiling)
|
|
torch._C._debug_set_autodiff_subgraph_inlining(self.inline_autodiff)
|
|
torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
|
|
torch._C._jit_override_can_fuse_on_cpu(self.can_fuse_on_cpu)
|
|
torch.set_default_dtype(self.default_dtype)
|
|
torch._C._jit_set_texpr_reductions_enabled(self.old_reduction_enabled)
|
|
torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
|
|
torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
|
|
|
|
def test_tensor_type_not_determined_by_inputs(self):
|
|
@torch.jit.script
|
|
def scalar_type_input(x, y, z):
|
|
return x + y + 4 + z.item()
|
|
|
|
x = torch.tensor([2, 2])
|
|
scalar_type_input(x, x, torch.tensor(1))
|
|
scalar_type_input(x, x, torch.tensor(1))
|
|
scalar_type_input(x, x, torch.tensor(1.0))
|
|
g = torch.jit.last_executed_optimized_graph()
|
|
|
|
# item & add should not get pulled into the fusion group -
|
|
# we expect to see Fusion Group (item / add) Fusion Group in ir dump
|
|
FileCheck().check("TensorExpr").check("Scalar = aten::item").check_next(
|
|
"Tensor = aten::add"
|
|
).check("TensorExpr").run(g)
|
|
|
|
@torch.jit.script
|
|
def non_const_dtype(x, y, cond: bool):
|
|
dtype = torch.int16 if cond else torch.int32
|
|
return (x + y + 3).sum(dtype=dtype)
|
|
|
|
non_const_dtype(x, x, True)
|
|
non_const_dtype(x, x, True)
|
|
g = torch.jit.last_executed_optimized_graph()
|
|
# because dtype is non-const, sum should not get pulled into the Fusion Group
|
|
FileCheck().check("TensorExpr").check("TensorExpr").check_not("aten::sum").run(
|
|
g
|
|
)
|
|
|
|
def test_specialize_backward(self):
|
|
def test_fuse(a, b):
|
|
c = a * b
|
|
d = c * b
|
|
return d
|
|
|
|
test_fuse.__disable_jit_function_caching__ = True
|
|
|
|
scripted_f = torch.jit.script(test_fuse)
|
|
x = torch.ones(1, requires_grad=True)
|
|
y = torch.ones(1, requires_grad=True)
|
|
scripted_f(x, y)
|
|
b = scripted_f(x, y)
|
|
warmup_backward(b)
|
|
g = torch.jit.last_executed_optimized_graph()
|
|
# Backward has an if node guarding specializations,
|
|
# within the if node true block there is only one if node
|
|
# that guards a tensorexpr group
|
|
optimized_block = next(g.findNode("prim::If").blocks())
|
|
if_nodes = list(optimized_block.findAllNodes("prim::If"))
|
|
|
|
self.assertEqual(len(if_nodes), 1)
|
|
FileCheck().check("Group[Subgraph").run(str(if_nodes[0]))
|
|
# no broadcasts occurred, sum_to_size have been specialized out
|
|
self.assertIsNone(optimized_block.findNode("aten::_grad_sum_to_size"))
|
|
|
|
broadcast_f = torch.jit.script(test_fuse)
|
|
x = torch.ones([2, 2], requires_grad=True)
|
|
y = torch.ones([1], requires_grad=True)
|
|
broadcast_f(x, y)
|
|
b = broadcast_f(x, y)
|
|
b.backward(torch.ones([2, 2], dtype=torch.float), retain_graph=True)
|
|
b.backward(torch.ones([2, 2], dtype=torch.float))
|
|
# warmup_backward(b, torch.ones([2, 2], dtype=torch.float))
|
|
g = torch.jit.last_executed_optimized_graph()
|
|
optimized_block = next(g.findNode("prim::If").blocks())
|
|
# broadcasts occurred, currently expect to see aten::_grad_sum_to_size
|
|
self.assertIsNotNone(optimized_block.findNode("aten::_grad_sum_to_size"))
|
|
|
|
def test_specialized_types(self):
|
|
@torch.jit.script
|
|
def test_fuse(a, b):
|
|
c = a * b
|
|
d = c * b
|
|
return d
|
|
|
|
x = torch.tensor([0.5])
|
|
for _ in range(3):
|
|
test_fuse(x, x)
|
|
|
|
g = torch.jit.last_executed_optimized_graph()
|
|
# Types should remain specialized for typecheck outputs & fusion outputs
|
|
FileCheck().check("Double(").check_same("prim::TypeCheck").check_same(
|
|
"\n"
|
|
).check("Double").check_same("TensorExpr").run(g)
|
|
|
|
# other outputs should not be specialized
|
|
FileCheck().check("Tensor = prim::If").run(g)
|
|
|
|
def test_aliasing_merge(self):
|
|
@torch.jit.script
|
|
def foo(a, b):
|
|
c = a * b
|
|
d = c * b
|
|
d.add_(b)
|
|
e = d * b
|
|
return d + e
|
|
|
|
x = torch.ones(1)
|
|
y = torch.ones(1)
|
|
foo(x, y)
|
|
b = foo(x, y) # noqa: F841
|
|
g = torch.jit.last_executed_optimized_graph()
|
|
self.assertEqual(len(list(g.findAllNodes("prim::TypeCheck"))), 2)
|
|
FileCheck().check("TensorExpr").check("aten::add_").check("TensorExpr").run(g)
|
|
|
|
def test_use_not_profiled(self):
|
|
def foo(t1, t2, t3, t4, t: float):
|
|
h = t1 + t2 + t3 + t4
|
|
if t > 0.5:
|
|
# Putting a use of t1 in a never-executed conditional prevents
|
|
return t1 + 1
|
|
return h
|
|
|
|
t = torch.rand(8, dtype=torch.float)
|
|
|
|
foo_script = torch.jit.script(foo)
|
|
for _ in range(torch._C._jit_get_num_profiled_runs() + 1):
|
|
foo_script(t, t, t, t, 0.1)
|
|
|
|
self.assertEqual(foo(t, t, t, t, 0.1), foo_script(t, t, t, t, 0.1))
|
|
g = torch.jit.last_executed_optimized_graph()
|
|
# all adds fused
|
|
FileCheck().check("graph").check_not("aten::add").check("prim::If").run(g)
|
|
|
|
def test_not_fusing_scalar_ops(self):
|
|
@torch.jit.script
|
|
def foo(x: int, y: int):
|
|
return x + y + 2 + 4 + 5 + 6
|
|
|
|
foo(1, 2)
|
|
foo(2, 3)
|
|
g = torch.jit.last_executed_optimized_graph()
|
|
FileCheck().check_not("TensorExpr").run(g)
|
|
|
|
def test_not_optimizing_property(self):
|
|
@torch.jit.script
|
|
def foo(x, y):
|
|
return x + y + 1 + 2 + 3, x.size()
|
|
|
|
x = torch.ones(1)
|
|
foo(x, x)
|
|
foo(x, x)
|
|
g = torch.jit.last_executed_optimized_graph()
|
|
FileCheck().check("aten::size").run(g)
|
|
x = torch.ones([2, 3, 5])
|
|
self.assertEqual(foo(x, x), (x + x + 1 + 2 + 3, x.size()))
|
|
|
|
def test_fallback_graph_not_specialized(self):
|
|
@torch.jit.script
|
|
def foo(a, b):
|
|
c = a * b
|
|
d = c * b
|
|
e = d * b
|
|
return d + e
|
|
|
|
x = torch.ones(1)
|
|
y = torch.ones(1)
|
|
foo(x, y)
|
|
foo(x, y)
|
|
g = torch.jit.last_executed_optimized_graph()
|
|
FileCheck().check("CallFunction").check_next("Tensor = prim::TupleUnpack").run(
|
|
g
|
|
)
|
|
|
|
def test_autograd_fallback_graph(self):
|
|
@torch.jit.script
|
|
def foo(a, b):
|
|
c = a * b
|
|
d = c * b
|
|
e = d * b
|
|
return d + e
|
|
|
|
x = torch.ones(1, requires_grad=True)
|
|
y = torch.ones(1, requires_grad=True)
|
|
foo(x, y)
|
|
b = foo(x, y)
|
|
b.backward(torch.ones([1], dtype=torch.float), retain_graph=True)
|
|
b.backward(torch.ones([1], dtype=torch.float))
|
|
|
|
g = torch.jit.last_executed_optimized_graph()
|
|
FileCheck().check("fallback_function").check_next("CallFunction").run(g)
|
|
|
|
def test_tensor_constant(self):
|
|
def foo(a, b):
|
|
return a + b + torch.tensor([2])
|
|
|
|
x = torch.ones(1, requires_grad=False)
|
|
foo_script = torch.jit.script(foo)
|
|
foo_script(x, x)
|
|
foo_script(x, x)
|
|
|
|
self.assertEqual(foo_script(x, x), foo(x, x))
|
|
g = torch.jit.last_executed_optimized_graph()
|
|
FileCheck().check_count("aten::add", 2, exactly=True).run(g)
|
|
|
|
def test_local_fusion_strategy(self):
|
|
@torch.jit.script
|
|
def foo(x):
|
|
return x + x + x
|
|
|
|
torch.jit.set_fusion_strategy([("STATIC", 1)])
|
|
for _ in range(3):
|
|
foo(torch.rand([10]))
|
|
|
|
torch.jit.set_fusion_strategy([("STATIC", 10)])
|
|
|
|
for i in range(10):
|
|
foo(torch.rand([i]))
|
|
foo(torch.rand([i]))
|
|
|
|
g = torch.jit.last_executed_optimized_graph()
|
|
FileCheck().check_count(":TensorExprGroup", 2, exactly=True).run(g)
|
|
|
|
def test_iterative_fusion(self):
|
|
@torch.jit.script
|
|
def foo(a, b, c, d):
|
|
a = a + b
|
|
b.add_(3)
|
|
c = c + b + d
|
|
a = a + 1
|
|
return a, c
|
|
|
|
x = torch.ones(1, requires_grad=False)
|
|
foo(x, x, x, x)
|
|
foo(x, x, x, x)
|
|
|
|
# when we iterate through the block, we will start
|
|
# by fusing a = a + b with a = a + 1
|
|
# if we were to continue iteration from that fusion point,
|
|
# would miss the fusion opportunity of c = c + d + b
|
|
|
|
g = torch.jit.last_executed_optimized_graph()
|
|
self.assertEqual(len(list(g.findAllNodes("prim::TensorExprGroup"))), 2)
|