mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
remove loop expects (#17695)
Summary: Replace loop unrolling expect files with assertions on the output IR Pull Request resolved: https://github.com/pytorch/pytorch/pull/17695 Differential Revision: D14347105 Pulled By: eellison fbshipit-source-id: 1703b4ca32bc1c67c01fc4330b0e6eb66feaa103
This commit is contained in:
committed by
Facebook Github Bot
parent
b87abdfc12
commit
7fa996f8e2
@ -1,44 +0,0 @@
|
|||||||
graph(%x : Tensor):
|
|
||||||
%1 : bool = prim::Constant[value=1]()
|
|
||||||
%y.1 : int = prim::Constant[value=0]()
|
|
||||||
%3 : int = prim::Int(%x)
|
|
||||||
%4 : int = prim::Constant[value=0]()
|
|
||||||
%5 : int = prim::Constant[value=8]()
|
|
||||||
%6 : int = aten::__round_to_zero_floordiv(%3, %5)
|
|
||||||
%7 : int = prim::Constant[value=8]()
|
|
||||||
%8 : int = aten::mul(%6, %7)
|
|
||||||
%9 : int = aten::sub(%3, %8)
|
|
||||||
%10 : int, %y.3 : int = prim::Loop(%6, %1, %4, %y.1)
|
|
||||||
block0(%i.1 : int, %13 : int, %14 : int):
|
|
||||||
%y.12 : int = aten::add(%14, %13)
|
|
||||||
%16 : int = prim::Constant[value=1]()
|
|
||||||
%17 : int = aten::add(%13, %16)
|
|
||||||
%y.5 : int = aten::add(%y.12, %17)
|
|
||||||
%19 : int = prim::Constant[value=1]()
|
|
||||||
%20 : int = aten::add(%17, %19)
|
|
||||||
%y.6 : int = aten::add(%y.5, %20)
|
|
||||||
%22 : int = prim::Constant[value=1]()
|
|
||||||
%23 : int = aten::add(%20, %22)
|
|
||||||
%y.7 : int = aten::add(%y.6, %23)
|
|
||||||
%25 : int = prim::Constant[value=1]()
|
|
||||||
%26 : int = aten::add(%23, %25)
|
|
||||||
%y.8 : int = aten::add(%y.7, %26)
|
|
||||||
%28 : int = prim::Constant[value=1]()
|
|
||||||
%29 : int = aten::add(%26, %28)
|
|
||||||
%y.9 : int = aten::add(%y.8, %29)
|
|
||||||
%31 : int = prim::Constant[value=1]()
|
|
||||||
%32 : int = aten::add(%29, %31)
|
|
||||||
%y.10 : int = aten::add(%y.9, %32)
|
|
||||||
%34 : int = prim::Constant[value=1]()
|
|
||||||
%35 : int = aten::add(%32, %34)
|
|
||||||
%y.11 : int = aten::add(%y.10, %35)
|
|
||||||
%37 : int = prim::Constant[value=1]()
|
|
||||||
%38 : int = aten::add(%35, %37)
|
|
||||||
-> (%1, %38, %y.11)
|
|
||||||
%39 : int, %y : int = prim::Loop(%9, %1, %10, %y.3)
|
|
||||||
block0(%i : int, %42 : int, %43 : int):
|
|
||||||
%y.4 : int = aten::add(%43, %42)
|
|
||||||
%45 : int = prim::Constant[value=1]()
|
|
||||||
%46 : int = aten::add(%42, %45)
|
|
||||||
-> (%1, %46, %y.4)
|
|
||||||
return (%y)
|
|
@ -1,14 +0,0 @@
|
|||||||
graph():
|
|
||||||
%y.1 : int = prim::Constant[value=0]()
|
|
||||||
%1 : int = prim::Constant[value=1]()
|
|
||||||
%y.11 : int = aten::add(%y.1, %1)
|
|
||||||
%y.2 : int = aten::add(%y.11, %1)
|
|
||||||
%y.3 : int = aten::add(%y.2, %1)
|
|
||||||
%y.4 : int = aten::add(%y.3, %1)
|
|
||||||
%y.5 : int = aten::add(%y.4, %1)
|
|
||||||
%y.6 : int = aten::add(%y.5, %1)
|
|
||||||
%y.7 : int = aten::add(%y.6, %1)
|
|
||||||
%y.8 : int = aten::add(%y.7, %1)
|
|
||||||
%y.9 : int = aten::add(%y.8, %1)
|
|
||||||
%y.10 : int = aten::add(%y.9, %1)
|
|
||||||
return (%y.10)
|
|
@ -1,32 +0,0 @@
|
|||||||
graph():
|
|
||||||
%y.1 : int = prim::Constant[value=0]()
|
|
||||||
%1 : int = prim::Constant[value=0]()
|
|
||||||
%y.11 : int = aten::add(%y.1, %1)
|
|
||||||
%3 : int = prim::Constant[value=1]()
|
|
||||||
%4 : int = aten::add(%1, %3)
|
|
||||||
%y.2 : int = aten::add(%y.11, %4)
|
|
||||||
%6 : int = prim::Constant[value=1]()
|
|
||||||
%7 : int = aten::add(%4, %6)
|
|
||||||
%y.3 : int = aten::add(%y.2, %7)
|
|
||||||
%9 : int = prim::Constant[value=1]()
|
|
||||||
%10 : int = aten::add(%7, %9)
|
|
||||||
%y.4 : int = aten::add(%y.3, %10)
|
|
||||||
%12 : int = prim::Constant[value=1]()
|
|
||||||
%13 : int = aten::add(%10, %12)
|
|
||||||
%y.5 : int = aten::add(%y.4, %13)
|
|
||||||
%15 : int = prim::Constant[value=1]()
|
|
||||||
%16 : int = aten::add(%13, %15)
|
|
||||||
%y.6 : int = aten::add(%y.5, %16)
|
|
||||||
%18 : int = prim::Constant[value=1]()
|
|
||||||
%19 : int = aten::add(%16, %18)
|
|
||||||
%y.7 : int = aten::add(%y.6, %19)
|
|
||||||
%21 : int = prim::Constant[value=1]()
|
|
||||||
%22 : int = aten::add(%19, %21)
|
|
||||||
%y.8 : int = aten::add(%y.7, %22)
|
|
||||||
%24 : int = prim::Constant[value=1]()
|
|
||||||
%25 : int = aten::add(%22, %24)
|
|
||||||
%y.9 : int = aten::add(%y.8, %25)
|
|
||||||
%27 : int = prim::Constant[value=1]()
|
|
||||||
%28 : int = aten::add(%25, %27)
|
|
||||||
%y.10 : int = aten::add(%y.9, %28)
|
|
||||||
return (%y.10)
|
|
@ -1,48 +0,0 @@
|
|||||||
graph(%x : Tensor):
|
|
||||||
%1 : bool = prim::Constant[value=1]()
|
|
||||||
%y.1 : int = prim::Constant[value=0]()
|
|
||||||
%3 : int = prim::Constant[value=10]()
|
|
||||||
%y : int = prim::Loop(%3, %1, %y.1)
|
|
||||||
block0(%5 : int, %6 : int):
|
|
||||||
%7 : int = prim::Int(%x)
|
|
||||||
%8 : int = prim::Constant[value=0]()
|
|
||||||
%9 : int = prim::Constant[value=8]()
|
|
||||||
%10 : int = aten::__round_to_zero_floordiv(%7, %9)
|
|
||||||
%11 : int = prim::Constant[value=8]()
|
|
||||||
%12 : int = aten::mul(%10, %11)
|
|
||||||
%13 : int = aten::sub(%7, %12)
|
|
||||||
%14 : int, %y.4 : int = prim::Loop(%10, %1, %8, %6)
|
|
||||||
block0(%j.1 : int, %17 : int, %18 : int):
|
|
||||||
%y.13 : int = aten::add(%18, %17)
|
|
||||||
%20 : int = prim::Constant[value=1]()
|
|
||||||
%21 : int = aten::add(%17, %20)
|
|
||||||
%y.6 : int = aten::add(%y.13, %21)
|
|
||||||
%23 : int = prim::Constant[value=1]()
|
|
||||||
%24 : int = aten::add(%21, %23)
|
|
||||||
%y.7 : int = aten::add(%y.6, %24)
|
|
||||||
%26 : int = prim::Constant[value=1]()
|
|
||||||
%27 : int = aten::add(%24, %26)
|
|
||||||
%y.8 : int = aten::add(%y.7, %27)
|
|
||||||
%29 : int = prim::Constant[value=1]()
|
|
||||||
%30 : int = aten::add(%27, %29)
|
|
||||||
%y.9 : int = aten::add(%y.8, %30)
|
|
||||||
%32 : int = prim::Constant[value=1]()
|
|
||||||
%33 : int = aten::add(%30, %32)
|
|
||||||
%y.10 : int = aten::add(%y.9, %33)
|
|
||||||
%35 : int = prim::Constant[value=1]()
|
|
||||||
%36 : int = aten::add(%33, %35)
|
|
||||||
%y.11 : int = aten::add(%y.10, %36)
|
|
||||||
%38 : int = prim::Constant[value=1]()
|
|
||||||
%39 : int = aten::add(%36, %38)
|
|
||||||
%y.12 : int = aten::add(%y.11, %39)
|
|
||||||
%41 : int = prim::Constant[value=1]()
|
|
||||||
%42 : int = aten::add(%39, %41)
|
|
||||||
-> (%1, %42, %y.12)
|
|
||||||
%43 : int, %y.3 : int = prim::Loop(%13, %1, %14, %y.4)
|
|
||||||
block0(%j : int, %46 : int, %47 : int):
|
|
||||||
%y.5 : int = aten::add(%47, %46)
|
|
||||||
%49 : int = prim::Constant[value=1]()
|
|
||||||
%50 : int = aten::add(%46, %49)
|
|
||||||
-> (%1, %50, %y.5)
|
|
||||||
-> (%1, %y.3)
|
|
||||||
return (%y)
|
|
@ -8015,31 +8015,34 @@ a")
|
|||||||
def fn(x):
|
def fn(x):
|
||||||
y = 0
|
y = 0
|
||||||
for i in range(int(x)):
|
for i in range(int(x)):
|
||||||
y += i
|
y -= i
|
||||||
return y
|
return y
|
||||||
|
|
||||||
graph = torch.jit.script(fn).graph
|
graph = torch.jit.script(fn).graph
|
||||||
self.run_pass('loop_unrolling', graph)
|
self.run_pass('loop_unrolling', graph)
|
||||||
self.assertExpectedGraph(graph)
|
unroll_factor = 8
|
||||||
|
FileCheck().check("prim::Loop").check_count("aten::sub", unroll_factor) \
|
||||||
|
.check("prim::Loop").check("aten::sub").run(str(graph))
|
||||||
self.checkScript(fn, (torch.tensor(10),))
|
self.checkScript(fn, (torch.tensor(10),))
|
||||||
|
|
||||||
def test_loop_unrolling_const(self):
|
def test_loop_unrolling_const(self):
|
||||||
def fn():
|
def fn():
|
||||||
y = 0
|
y = 0
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
y += 1
|
y -= 1
|
||||||
return y
|
return y
|
||||||
|
|
||||||
def fn2():
|
def fn2():
|
||||||
y = 0
|
y = 0
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
y += i
|
y -= i
|
||||||
return y
|
return y
|
||||||
|
|
||||||
def check(fn, name):
|
def check(fn, name):
|
||||||
graph = torch.jit.script(fn).graph
|
graph = torch.jit.script(fn).graph
|
||||||
self.run_pass('loop_unrolling', graph)
|
self.run_pass('loop_unrolling', graph)
|
||||||
self.assertExpectedGraph(graph, subname=name)
|
# entirely unrolled
|
||||||
|
FileCheck().check_not("prim::Loop'").run(str(graph))
|
||||||
self.checkScript(fn, ())
|
self.checkScript(fn, ())
|
||||||
|
|
||||||
check(fn, 'add_const')
|
check(fn, 'add_const')
|
||||||
@ -8050,24 +8053,28 @@ a")
|
|||||||
y = 0
|
y = 0
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
for j in range(int(x)):
|
for j in range(int(x)):
|
||||||
y += j
|
y -= j
|
||||||
return y
|
return y
|
||||||
|
|
||||||
graph = torch.jit.script(fn).graph
|
graph = torch.jit.script(fn).graph
|
||||||
self.run_pass('loop_unrolling', graph)
|
self.run_pass('loop_unrolling', graph)
|
||||||
self.assertExpectedGraph(graph)
|
# inner loop with 8 subs followed by loop epilogue
|
||||||
|
unroll_factor = 8
|
||||||
|
FileCheck().check("prim::Loop").check("prim::Loop").check_count('aten::sub', unroll_factor) \
|
||||||
|
.check("prim::Loop").check("aten::sub").run(str(graph))
|
||||||
self.checkScript(fn, (torch.tensor(10),))
|
self.checkScript(fn, (torch.tensor(10),))
|
||||||
|
|
||||||
def test_loop_unroll_unused_counter(self):
|
def test_loop_unroll_unused_counter(self):
|
||||||
def fn(x):
|
def fn(x):
|
||||||
y = 0
|
y = 0
|
||||||
for _ in range(int(x)):
|
for _ in range(int(x)):
|
||||||
y += 1
|
y -= 1
|
||||||
return y
|
return y
|
||||||
|
|
||||||
graph = torch.jit.script(fn).graph
|
graph = torch.jit.script(fn).graph
|
||||||
self.run_pass('loop_unrolling', graph)
|
self.run_pass('loop_unrolling', graph)
|
||||||
self.assertExpectedGraph(graph)
|
FileCheck().check("prim::Loop").check_not("aten::add").check("return") \
|
||||||
|
.run(str(graph))
|
||||||
|
|
||||||
def test_loop_unroll_negative(self):
|
def test_loop_unroll_negative(self):
|
||||||
def fn(x):
|
def fn(x):
|
||||||
|
Reference in New Issue
Block a user