remove loop expects (#17695)

Summary:
Replace loop unrolling expect files with assertions on the output IR
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17695

Differential Revision: D14347105

Pulled By: eellison

fbshipit-source-id: 1703b4ca32bc1c67c01fc4330b0e6eb66feaa103
This commit is contained in:
Elias Ellison
2019-03-06 11:42:19 -08:00
committed by Facebook Github Bot
parent b87abdfc12
commit 7fa996f8e2
5 changed files with 16 additions and 147 deletions

View File

@ -1,44 +0,0 @@
graph(%x : Tensor):
%1 : bool = prim::Constant[value=1]()
%y.1 : int = prim::Constant[value=0]()
%3 : int = prim::Int(%x)
%4 : int = prim::Constant[value=0]()
%5 : int = prim::Constant[value=8]()
%6 : int = aten::__round_to_zero_floordiv(%3, %5)
%7 : int = prim::Constant[value=8]()
%8 : int = aten::mul(%6, %7)
%9 : int = aten::sub(%3, %8)
%10 : int, %y.3 : int = prim::Loop(%6, %1, %4, %y.1)
block0(%i.1 : int, %13 : int, %14 : int):
%y.12 : int = aten::add(%14, %13)
%16 : int = prim::Constant[value=1]()
%17 : int = aten::add(%13, %16)
%y.5 : int = aten::add(%y.12, %17)
%19 : int = prim::Constant[value=1]()
%20 : int = aten::add(%17, %19)
%y.6 : int = aten::add(%y.5, %20)
%22 : int = prim::Constant[value=1]()
%23 : int = aten::add(%20, %22)
%y.7 : int = aten::add(%y.6, %23)
%25 : int = prim::Constant[value=1]()
%26 : int = aten::add(%23, %25)
%y.8 : int = aten::add(%y.7, %26)
%28 : int = prim::Constant[value=1]()
%29 : int = aten::add(%26, %28)
%y.9 : int = aten::add(%y.8, %29)
%31 : int = prim::Constant[value=1]()
%32 : int = aten::add(%29, %31)
%y.10 : int = aten::add(%y.9, %32)
%34 : int = prim::Constant[value=1]()
%35 : int = aten::add(%32, %34)
%y.11 : int = aten::add(%y.10, %35)
%37 : int = prim::Constant[value=1]()
%38 : int = aten::add(%35, %37)
-> (%1, %38, %y.11)
%39 : int, %y : int = prim::Loop(%9, %1, %10, %y.3)
block0(%i : int, %42 : int, %43 : int):
%y.4 : int = aten::add(%43, %42)
%45 : int = prim::Constant[value=1]()
%46 : int = aten::add(%42, %45)
-> (%1, %46, %y.4)
return (%y)

View File

@ -1,14 +0,0 @@
graph():
%y.1 : int = prim::Constant[value=0]()
%1 : int = prim::Constant[value=1]()
%y.11 : int = aten::add(%y.1, %1)
%y.2 : int = aten::add(%y.11, %1)
%y.3 : int = aten::add(%y.2, %1)
%y.4 : int = aten::add(%y.3, %1)
%y.5 : int = aten::add(%y.4, %1)
%y.6 : int = aten::add(%y.5, %1)
%y.7 : int = aten::add(%y.6, %1)
%y.8 : int = aten::add(%y.7, %1)
%y.9 : int = aten::add(%y.8, %1)
%y.10 : int = aten::add(%y.9, %1)
return (%y.10)

View File

@ -1,32 +0,0 @@
graph():
%y.1 : int = prim::Constant[value=0]()
%1 : int = prim::Constant[value=0]()
%y.11 : int = aten::add(%y.1, %1)
%3 : int = prim::Constant[value=1]()
%4 : int = aten::add(%1, %3)
%y.2 : int = aten::add(%y.11, %4)
%6 : int = prim::Constant[value=1]()
%7 : int = aten::add(%4, %6)
%y.3 : int = aten::add(%y.2, %7)
%9 : int = prim::Constant[value=1]()
%10 : int = aten::add(%7, %9)
%y.4 : int = aten::add(%y.3, %10)
%12 : int = prim::Constant[value=1]()
%13 : int = aten::add(%10, %12)
%y.5 : int = aten::add(%y.4, %13)
%15 : int = prim::Constant[value=1]()
%16 : int = aten::add(%13, %15)
%y.6 : int = aten::add(%y.5, %16)
%18 : int = prim::Constant[value=1]()
%19 : int = aten::add(%16, %18)
%y.7 : int = aten::add(%y.6, %19)
%21 : int = prim::Constant[value=1]()
%22 : int = aten::add(%19, %21)
%y.8 : int = aten::add(%y.7, %22)
%24 : int = prim::Constant[value=1]()
%25 : int = aten::add(%22, %24)
%y.9 : int = aten::add(%y.8, %25)
%27 : int = prim::Constant[value=1]()
%28 : int = aten::add(%25, %27)
%y.10 : int = aten::add(%y.9, %28)
return (%y.10)

View File

@ -1,48 +0,0 @@
graph(%x : Tensor):
%1 : bool = prim::Constant[value=1]()
%y.1 : int = prim::Constant[value=0]()
%3 : int = prim::Constant[value=10]()
%y : int = prim::Loop(%3, %1, %y.1)
block0(%5 : int, %6 : int):
%7 : int = prim::Int(%x)
%8 : int = prim::Constant[value=0]()
%9 : int = prim::Constant[value=8]()
%10 : int = aten::__round_to_zero_floordiv(%7, %9)
%11 : int = prim::Constant[value=8]()
%12 : int = aten::mul(%10, %11)
%13 : int = aten::sub(%7, %12)
%14 : int, %y.4 : int = prim::Loop(%10, %1, %8, %6)
block0(%j.1 : int, %17 : int, %18 : int):
%y.13 : int = aten::add(%18, %17)
%20 : int = prim::Constant[value=1]()
%21 : int = aten::add(%17, %20)
%y.6 : int = aten::add(%y.13, %21)
%23 : int = prim::Constant[value=1]()
%24 : int = aten::add(%21, %23)
%y.7 : int = aten::add(%y.6, %24)
%26 : int = prim::Constant[value=1]()
%27 : int = aten::add(%24, %26)
%y.8 : int = aten::add(%y.7, %27)
%29 : int = prim::Constant[value=1]()
%30 : int = aten::add(%27, %29)
%y.9 : int = aten::add(%y.8, %30)
%32 : int = prim::Constant[value=1]()
%33 : int = aten::add(%30, %32)
%y.10 : int = aten::add(%y.9, %33)
%35 : int = prim::Constant[value=1]()
%36 : int = aten::add(%33, %35)
%y.11 : int = aten::add(%y.10, %36)
%38 : int = prim::Constant[value=1]()
%39 : int = aten::add(%36, %38)
%y.12 : int = aten::add(%y.11, %39)
%41 : int = prim::Constant[value=1]()
%42 : int = aten::add(%39, %41)
-> (%1, %42, %y.12)
%43 : int, %y.3 : int = prim::Loop(%13, %1, %14, %y.4)
block0(%j : int, %46 : int, %47 : int):
%y.5 : int = aten::add(%47, %46)
%49 : int = prim::Constant[value=1]()
%50 : int = aten::add(%46, %49)
-> (%1, %50, %y.5)
-> (%1, %y.3)
return (%y)

View File

@ -8015,31 +8015,34 @@ a")
def fn(x): def fn(x):
y = 0 y = 0
for i in range(int(x)): for i in range(int(x)):
y += i y -= i
return y return y
graph = torch.jit.script(fn).graph graph = torch.jit.script(fn).graph
self.run_pass('loop_unrolling', graph) self.run_pass('loop_unrolling', graph)
self.assertExpectedGraph(graph) unroll_factor = 8
FileCheck().check("prim::Loop").check_count("aten::sub", unroll_factor) \
.check("prim::Loop").check("aten::sub").run(str(graph))
self.checkScript(fn, (torch.tensor(10),)) self.checkScript(fn, (torch.tensor(10),))
def test_loop_unrolling_const(self): def test_loop_unrolling_const(self):
def fn(): def fn():
y = 0 y = 0
for _ in range(10): for _ in range(10):
y += 1 y -= 1
return y return y
def fn2(): def fn2():
y = 0 y = 0
for i in range(10): for i in range(10):
y += i y -= i
return y return y
def check(fn, name): def check(fn, name):
graph = torch.jit.script(fn).graph graph = torch.jit.script(fn).graph
self.run_pass('loop_unrolling', graph) self.run_pass('loop_unrolling', graph)
self.assertExpectedGraph(graph, subname=name) # entirely unrolled
FileCheck().check_not("prim::Loop'").run(str(graph))
self.checkScript(fn, ()) self.checkScript(fn, ())
check(fn, 'add_const') check(fn, 'add_const')
@ -8050,24 +8053,28 @@ a")
y = 0 y = 0
for _ in range(10): for _ in range(10):
for j in range(int(x)): for j in range(int(x)):
y += j y -= j
return y return y
graph = torch.jit.script(fn).graph graph = torch.jit.script(fn).graph
self.run_pass('loop_unrolling', graph) self.run_pass('loop_unrolling', graph)
self.assertExpectedGraph(graph) # inner loop with 8 subs followed by loop epilogue
unroll_factor = 8
FileCheck().check("prim::Loop").check("prim::Loop").check_count('aten::sub', unroll_factor) \
.check("prim::Loop").check("aten::sub").run(str(graph))
self.checkScript(fn, (torch.tensor(10),)) self.checkScript(fn, (torch.tensor(10),))
def test_loop_unroll_unused_counter(self): def test_loop_unroll_unused_counter(self):
def fn(x): def fn(x):
y = 0 y = 0
for _ in range(int(x)): for _ in range(int(x)):
y += 1 y -= 1
return y return y
graph = torch.jit.script(fn).graph graph = torch.jit.script(fn).graph
self.run_pass('loop_unrolling', graph) self.run_pass('loop_unrolling', graph)
self.assertExpectedGraph(graph) FileCheck().check("prim::Loop").check_not("aten::add").check("return") \
.run(str(graph))
def test_loop_unroll_negative(self): def test_loop_unroll_negative(self):
def fn(x): def fn(x):