constant pooling pass (#12222)

Summary:
Add a pass that moves all constants to the beginning of the graph and deduplicates them.

This extends https://github.com/pytorch/pytorch/pull/10231 to also handle constants introduced by inlining, constant propagation, etc.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12222
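For a concrete sense of the effect, here is a minimal sketch of driving the pass from Python via the _jit_pass_constant_pooling binding registered in this PR (the scripted function below is a made-up example, not part of the change):

import torch

@torch.jit.script
def repeated_ones(x):
    # each use of the literal 1 (including the implicit alpha argument
    # of aten::add) initially emits its own prim::Constant node
    return (x + 1) + 1

graph = repeated_ones.graph
torch._C._jit_pass_constant_pooling(graph)
# after pooling, a single prim::Constant[value=1]() is hoisted to the
# top of the graph and all uses point at it
print(graph)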

Reviewed By: driazati

Differential Revision: D10201616

Pulled By: eellison

fbshipit-source-id: bc9c5be26868c8b5414257a0d4462de025aeb9bd
Elias Ellison
2018-10-08 11:50:51 -07:00
committed by Facebook Github Bot
parent 83b4dc6822
commit 00aedfc0e2
52 changed files with 621 additions and 501 deletions

View File

@@ -4,19 +4,18 @@ graph(%x.1_data : Dynamic
%y_data : Dynamic
%y_mask : Dynamic
%y_dims : Dynamic) {
%6 : int = prim::Constant[value=10]()
%6 : int = prim::Constant[value=1]()
%7 : bool = prim::Constant[value=1]()
%x : Dynamic, %9 : Dynamic, %10 : Dynamic = prim::Loop(%6, %7, %x.1_data, %x.1_mask, %x.1_dims)
%8 : int = prim::Constant[value=10]()
%x : Dynamic, %10 : Dynamic, %11 : Dynamic = prim::Loop(%8, %7, %x.1_data, %x.1_mask, %x.1_dims)
block0(%loop_num : int, %5_data : Dynamic, %5_mask : Dynamic, %5_dims : Dynamic) {
%15 : int = prim::Constant[value=1]()
%16 : Long() = prim::NumToTensor(%15)
%16 : Long() = prim::NumToTensor(%6)
%alpha : float = prim::TensorToNum(%16)
%data.1 : Dynamic = aten::add(%5_data, %y_data, %alpha)
%mask : Dynamic = aten::mul(%5_mask, %y_mask)
%dims : Dynamic = aten::__or__(%5_dims, %y_dims)
%21 : bool = prim::Constant[value=1]()
%data : Dynamic = aten::where(%mask, %data.1, %5_data)
-> (%21, %data, %mask, %dims)
-> (%7, %data, %mask, %dims)
}
return (%x, %9, %10);
return (%x, %10, %11);
}

View File

@@ -4,38 +4,36 @@ graph(%a.1_data : Dynamic
%b_data : Dynamic
%b_mask : Dynamic
%b_dims : Dynamic) {
%6 : Dynamic = aten::gt(%a.1_data, %b_data)
%7 : Dynamic = aten::mul(%a.1_mask, %b_mask)
%8 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
%9 : bool = prim::TensorToBool(%6)
%10 : int = prim::Constant[value=1]()
%11 : Long() = prim::NumToTensor(%10)
%6 : int = prim::Constant[value=1]()
%7 : Dynamic = aten::gt(%a.1_data, %b_data)
%8 : Dynamic = aten::mul(%a.1_mask, %b_mask)
%9 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
%10 : bool = prim::TensorToBool(%7)
%11 : Long() = prim::NumToTensor(%6)
%alpha.1 : float = prim::TensorToNum(%11)
%data.1 : Dynamic = aten::add(%a.1_data, %b_data, %alpha.1)
%mask.1 : Dynamic = aten::mul(%a.1_mask, %b_mask)
%dims.1 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
%16 : int = prim::Constant[value=1]()
%17 : Long() = prim::NumToTensor(%16)
%alpha : float = prim::TensorToNum(%17)
%16 : Long() = prim::NumToTensor(%6)
%alpha : float = prim::TensorToNum(%16)
%data.4 : Dynamic = aten::sub(%a.1_data, %b_data, %alpha)
%mask : Dynamic = aten::mul(%a.1_mask, %b_mask)
%dims : Dynamic = aten::__or__(%a.1_dims, %b_dims)
%21 : bool = prim::Constant[value=1]()
%22 : int = prim::Constant[value=1]()
%23 : Dynamic = aten::type_as(%7, %6)
%cond_mask.1 : Dynamic = aten::mul(%6, %23)
%23 : Dynamic = aten::type_as(%8, %7)
%cond_mask.1 : Dynamic = aten::mul(%7, %23)
%25 : int = aten::dim(%cond_mask.1)
%26 : bool = aten::eq(%25, %22)
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%26)
block0() {
%30 : int = aten::dim(%data.1)
%31 : int = aten::sub(%30, %22)
%32 : bool = prim::Constant[value=1]()
%data.3 : Dynamic = prim::Loop(%31, %32, %cond_mask.1)
block0(%_ : int, %35 : Dynamic) {
%36 : int = aten::dim(%35)
%data.2 : Dynamic = aten::unsqueeze(%35, %36)
%38 : bool = prim::Constant[value=1]()
-> (%38, %data.2)
%data.3 : Dynamic = prim::Loop(%31, %21, %cond_mask.1)
block0(%_ : int, %34 : Dynamic) {
%35 : int = aten::dim(%34)
%data.2 : Dynamic = aten::unsqueeze(%34, %35)
-> (%21, %data.2)
}
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
%cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask.1)

View File

@@ -4,39 +4,37 @@ graph(%a.1_data : Dynamic
%b_data : Dynamic
%b_mask : Dynamic
%b_dims : Dynamic) {
%6 : float = prim::Constant[value=0.1]()
%7 : Float() = prim::NumToTensor(%6)
%other : float = prim::TensorToNum(%7)
%9 : Dynamic = aten::gt(%a.1_data, %other)
%10 : bool = prim::TensorToBool(%9)
%11 : int = prim::Constant[value=1]()
%12 : Long() = prim::NumToTensor(%11)
%6 : int = prim::Constant[value=1]()
%7 : float = prim::Constant[value=0.1]()
%8 : Float() = prim::NumToTensor(%7)
%other : float = prim::TensorToNum(%8)
%10 : Dynamic = aten::gt(%a.1_data, %other)
%11 : bool = prim::TensorToBool(%10)
%12 : Long() = prim::NumToTensor(%6)
%alpha.1 : float = prim::TensorToNum(%12)
%data.1 : Dynamic = aten::add(%a.1_data, %b_data, %alpha.1)
%mask.1 : Dynamic = aten::mul(%a.1_mask, %b_mask)
%dims.1 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
%17 : int = prim::Constant[value=1]()
%18 : Long() = prim::NumToTensor(%17)
%alpha : float = prim::TensorToNum(%18)
%17 : Long() = prim::NumToTensor(%6)
%alpha : float = prim::TensorToNum(%17)
%data.4 : Dynamic = aten::sub(%a.1_data, %b_data, %alpha)
%mask : Dynamic = aten::mul(%a.1_mask, %b_mask)
%dims : Dynamic = aten::__or__(%a.1_dims, %b_dims)
%22 : bool = prim::Constant[value=1]()
%23 : int = prim::Constant[value=1]()
%24 : Dynamic = aten::type_as(%a.1_mask, %9)
%cond_mask.1 : Dynamic = aten::mul(%9, %24)
%24 : Dynamic = aten::type_as(%a.1_mask, %10)
%cond_mask.1 : Dynamic = aten::mul(%10, %24)
%26 : int = aten::dim(%cond_mask.1)
%27 : bool = aten::eq(%26, %23)
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%27)
block0() {
%31 : int = aten::dim(%data.1)
%32 : int = aten::sub(%31, %23)
%33 : bool = prim::Constant[value=1]()
%data.3 : Dynamic = prim::Loop(%32, %33, %cond_mask.1)
block0(%_ : int, %36 : Dynamic) {
%37 : int = aten::dim(%36)
%data.2 : Dynamic = aten::unsqueeze(%36, %37)
%39 : bool = prim::Constant[value=1]()
-> (%39, %data.2)
%data.3 : Dynamic = prim::Loop(%32, %22, %cond_mask.1)
block0(%_ : int, %35 : Dynamic) {
%36 : int = aten::dim(%35)
%data.2 : Dynamic = aten::unsqueeze(%35, %36)
-> (%22, %data.2)
}
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
%cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask.1)

View File

@@ -4,32 +4,31 @@ graph(%a.1_data : Dynamic
%b_data : Dynamic
%b_mask : Dynamic
%b_dims : Dynamic) {
%6 : Dynamic = aten::gt(%a.1_data, %b_data)
%7 : Dynamic = aten::mul(%a.1_mask, %b_mask)
%8 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
%9 : bool = prim::TensorToBool(%6)
%10 : int = prim::Constant[value=1]()
%11 : Long() = prim::NumToTensor(%10)
%6 : int = prim::Constant[value=1]()
%7 : Dynamic = aten::gt(%a.1_data, %b_data)
%8 : Dynamic = aten::mul(%a.1_mask, %b_mask)
%9 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
%10 : bool = prim::TensorToBool(%7)
%11 : Long() = prim::NumToTensor(%6)
%alpha : float = prim::TensorToNum(%11)
%data.1 : Dynamic = aten::add(%a.1_data, %b_data, %alpha)
%mask : Dynamic = aten::mul(%a.1_mask, %b_mask)
%dims : Dynamic = aten::__or__(%a.1_dims, %b_dims)
%16 : int = prim::Constant[value=1]()
%17 : Dynamic = aten::type_as(%7, %6)
%cond_mask.1 : Dynamic = aten::mul(%6, %17)
%19 : int = aten::dim(%cond_mask.1)
%20 : bool = aten::eq(%19, %16)
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%20)
%16 : bool = prim::Constant[value=1]()
%17 : int = prim::Constant[value=1]()
%18 : Dynamic = aten::type_as(%8, %7)
%cond_mask.1 : Dynamic = aten::mul(%7, %18)
%20 : int = aten::dim(%cond_mask.1)
%21 : bool = aten::eq(%20, %17)
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%21)
block0() {
%24 : int = aten::dim(%data.1)
%25 : int = aten::sub(%24, %16)
%26 : bool = prim::Constant[value=1]()
%data.3 : Dynamic = prim::Loop(%25, %26, %cond_mask.1)
%25 : int = aten::dim(%data.1)
%26 : int = aten::sub(%25, %17)
%data.3 : Dynamic = prim::Loop(%26, %16, %cond_mask.1)
block0(%_ : int, %29 : Dynamic) {
%30 : int = aten::dim(%29)
%data.2 : Dynamic = aten::unsqueeze(%29, %30)
%32 : bool = prim::Constant[value=1]()
-> (%32, %data.2)
-> (%16, %data.2)
}
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
%cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask)

View File

@@ -4,33 +4,32 @@ graph(%a.1_data : Dynamic
%b_data : Dynamic
%b_mask : Dynamic
%b_dims : Dynamic) {
%6 : float = prim::Constant[value=0.1]()
%7 : Float() = prim::NumToTensor(%6)
%other : float = prim::TensorToNum(%7)
%9 : Dynamic = aten::gt(%a.1_data, %other)
%10 : bool = prim::TensorToBool(%9)
%11 : int = prim::Constant[value=1]()
%12 : Long() = prim::NumToTensor(%11)
%6 : int = prim::Constant[value=1]()
%7 : float = prim::Constant[value=0.1]()
%8 : Float() = prim::NumToTensor(%7)
%other : float = prim::TensorToNum(%8)
%10 : Dynamic = aten::gt(%a.1_data, %other)
%11 : bool = prim::TensorToBool(%10)
%12 : Long() = prim::NumToTensor(%6)
%alpha : float = prim::TensorToNum(%12)
%data.1 : Dynamic = aten::add(%a.1_data, %b_data, %alpha)
%mask : Dynamic = aten::mul(%a.1_mask, %b_mask)
%dims : Dynamic = aten::__or__(%a.1_dims, %b_dims)
%17 : int = prim::Constant[value=1]()
%18 : Dynamic = aten::type_as(%a.1_mask, %9)
%cond_mask.1 : Dynamic = aten::mul(%9, %18)
%20 : int = aten::dim(%cond_mask.1)
%21 : bool = aten::eq(%20, %17)
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%21)
%17 : bool = prim::Constant[value=1]()
%18 : int = prim::Constant[value=1]()
%19 : Dynamic = aten::type_as(%a.1_mask, %10)
%cond_mask.1 : Dynamic = aten::mul(%10, %19)
%21 : int = aten::dim(%cond_mask.1)
%22 : bool = aten::eq(%21, %18)
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%22)
block0() {
%25 : int = aten::dim(%data.1)
%26 : int = aten::sub(%25, %17)
%27 : bool = prim::Constant[value=1]()
%data.3 : Dynamic = prim::Loop(%26, %27, %cond_mask.1)
%26 : int = aten::dim(%data.1)
%27 : int = aten::sub(%26, %18)
%data.3 : Dynamic = prim::Loop(%27, %17, %cond_mask.1)
block0(%_ : int, %30 : Dynamic) {
%31 : int = aten::dim(%30)
%data.2 : Dynamic = aten::unsqueeze(%30, %31)
%33 : bool = prim::Constant[value=1]()
-> (%33, %data.2)
-> (%17, %data.2)
}
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
%cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask)

View File

@@ -4,20 +4,20 @@ graph(%a.1_data : Dynamic
%b_data : Dynamic
%b_mask : Dynamic
%b_dims : Dynamic) {
%6 : int = prim::Constant[value=2147483647]()
%7 : Dynamic = aten::gt(%a.1_data, %b_data)
%8 : Dynamic = aten::mul(%a.1_mask, %b_mask)
%9 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
%10 : bool = prim::TensorToBool(%7)
%11 : int = prim::Constant[value=0]()
%12 : Dynamic = aten::mul(%7, %8)
%13 : Dynamic = aten::sum(%12)
%14 : Dynamic = aten::gt(%13, %11)
%15 : bool = prim::TensorToBool(%14)
%16 : Dynamic, %17 : Dynamic, %18 : Dynamic, %a : Dynamic, %20 : Dynamic, %21 : Dynamic = prim::Loop(%6, %15, %7, %8, %9, %a.1_data, %a.1_mask, %a.1_dims)
%6 : int = prim::Constant[value=1]()
%7 : int = prim::Constant[value=2147483647]()
%8 : Dynamic = aten::gt(%a.1_data, %b_data)
%9 : Dynamic = aten::mul(%a.1_mask, %b_mask)
%10 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
%11 : bool = prim::TensorToBool(%8)
%12 : int = prim::Constant[value=0]()
%13 : Dynamic = aten::mul(%8, %9)
%14 : Dynamic = aten::sum(%13)
%15 : Dynamic = aten::gt(%14, %12)
%16 : bool = prim::TensorToBool(%15)
%17 : Dynamic, %18 : Dynamic, %19 : Dynamic, %a : Dynamic, %21 : Dynamic, %22 : Dynamic = prim::Loop(%7, %16, %8, %9, %10, %a.1_data, %a.1_mask, %a.1_dims)
block0(%loop_num : int, %cond_data.2 : Dynamic, %cond_mask.3 : Dynamic, %cond_dims : Dynamic, %6_data : Dynamic, %6_mask : Dynamic, %6_dims : Dynamic) {
%29 : int = prim::Constant[value=1]()
%30 : Long() = prim::NumToTensor(%29)
%30 : Long() = prim::NumToTensor(%6)
%alpha : float = prim::TensorToNum(%30)
%data.1 : Dynamic = aten::sub(%6_data, %b_data, %alpha)
%mask : Dynamic = aten::mul(%6_mask, %b_mask)
@@ -26,22 +26,21 @@ graph(%a.1_data : Dynamic
%36 : Dynamic = aten::mul(%mask, %b_mask)
%37 : Dynamic = aten::__or__(%dims, %b_dims)
%38 : bool = prim::TensorToBool(%35)
%39 : int = prim::Constant[value=1]()
%40 : Dynamic = aten::type_as(%cond_mask.3, %cond_data.2)
%cond_mask.1 : Dynamic = aten::mul(%cond_data.2, %40)
%42 : int = aten::dim(%cond_mask.1)
%43 : bool = aten::eq(%42, %39)
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%43)
%39 : bool = prim::Constant[value=1]()
%40 : int = prim::Constant[value=1]()
%41 : Dynamic = aten::type_as(%cond_mask.3, %cond_data.2)
%cond_mask.1 : Dynamic = aten::mul(%cond_data.2, %41)
%43 : int = aten::dim(%cond_mask.1)
%44 : bool = aten::eq(%43, %40)
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%44)
block0() {
%47 : int = aten::dim(%data.1)
%48 : int = aten::sub(%47, %39)
%49 : bool = prim::Constant[value=1]()
%data.3 : Dynamic = prim::Loop(%48, %49, %cond_mask.1)
%48 : int = aten::dim(%data.1)
%49 : int = aten::sub(%48, %40)
%data.3 : Dynamic = prim::Loop(%49, %39, %cond_mask.1)
block0(%_ : int, %52 : Dynamic) {
%53 : int = aten::dim(%52)
%data.2 : Dynamic = aten::unsqueeze(%52, %53)
%55 : bool = prim::Constant[value=1]()
-> (%55, %data.2)
-> (%39, %data.2)
}
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
%cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask)
@@ -53,12 +52,12 @@ graph(%a.1_data : Dynamic
%res_data : Dynamic = aten::where(%cond_data, %data.1, %6_data)
%res_mask : Dynamic = aten::where(%cond_mask, %mask, %6_mask)
%res_dims : Dynamic = aten::__or__(%dims, %6_dims)
%61 : int = prim::Constant[value=0]()
%62 : Dynamic = aten::mul(%35, %36)
%63 : Dynamic = aten::sum(%62)
%64 : Dynamic = aten::gt(%63, %61)
%65 : bool = prim::TensorToBool(%64)
-> (%65, %35, %36, %37, %res_data, %res_mask, %res_dims)
%60 : int = prim::Constant[value=0]()
%61 : Dynamic = aten::mul(%35, %36)
%62 : Dynamic = aten::sum(%61)
%63 : Dynamic = aten::gt(%62, %60)
%64 : bool = prim::TensorToBool(%63)
-> (%64, %35, %36, %37, %res_data, %res_mask, %res_dims)
}
return (%a, %20, %21);
return (%a, %21, %22);
}

View File

@@ -21,11 +21,8 @@ graph(%a : Dynamic
}
%c0.5 : int = aten::add(%c0.4, %c2.1)
%11 : int = prim::Constant[value=5]()
%12 : int = prim::Constant[value=1]()
%13 : Dynamic = aten::add(%a, %c0.5, %12)
%14 : int = prim::Constant[value=1]()
%15 : Dynamic = aten::add(%13, %c1, %14)
%16 : int = prim::Constant[value=1]()
%17 : Dynamic = aten::add(%15, %11, %16)
return (%17);
%12 : Dynamic = aten::add(%a, %c0.5, %c2.1)
%13 : Dynamic = aten::add(%12, %c1, %c2.1)
%14 : Dynamic = aten::add(%13, %11, %c2.1)
return (%14);
}

View File

@@ -1,19 +1,17 @@
graph() {
%b.4 : int = prim::Constant[value=2]()
%b.2 : int = prim::Constant[value=1]()
%2 : int = prim::Constant[value=2147483647]()
%0 : bool = prim::Constant[value=0]()
%1 : bool = prim::Constant[value=1]()
%b.1 : int = prim::Constant[value=0]()
%4 : bool = prim::Constant[value=1]()
%b.3 : int = prim::Loop(%2, %4, %b.1)
block0(%6 : int, %7 : int) {
%8 : bool = prim::Constant[value=1]()
-> (%8, %b.2)
%3 : int = prim::Constant[value=2147483647]()
%b.2 : int = prim::Constant[value=1]()
%b.4 : int = prim::Constant[value=2]()
%b.3 : int = prim::Loop(%3, %1, %b.1)
block0(%7 : int, %8 : int) {
-> (%1, %b.2)
}
%9 : bool = prim::Constant[value=0]()
%b : int = prim::Loop(%2, %9, %b.3)
block0(%11 : int, %12 : int) {
%13 : bool = prim::Constant[value=0]()
-> (%13, %b.4)
%b : int = prim::Loop(%3, %0, %b.3)
block0(%10 : int, %11 : int) {
-> (%0, %b.4)
}
return (%b);
}

View File

@@ -1,8 +1,8 @@
graph(%input_tensor : Dynamic) {
%1 : int = prim::Constant[value=6]()
= prim::Print(%1)
%2 : int = prim::Constant[value=8]()
%3 : int = prim::Constant[value=1]()
%4 : Dynamic = aten::add(%input_tensor, %2, %3)
%1 : int = prim::Constant[value=1]()
%2 : int = prim::Constant[value=6]()
= prim::Print(%2)
%3 : int = prim::Constant[value=8]()
%4 : Dynamic = aten::add(%input_tensor, %3, %1)
return (%4);
}

View File

@@ -1,11 +1,11 @@
graph() {
%0 : int = prim::Constant[value=2]()
%1 : int[] = prim::Constant[value=[3]]()
%2 : int = prim::Constant[value=6]()
%3 : int = prim::Constant[value=0]()
%4 : int[] = prim::Constant[value=[0, -1]]()
%a : Dynamic = aten::randn(%1, %2, %3, %4)
%6 : int = prim::Constant[value=1]()
%b : Dynamic = aten::add(%a, %0, %6)
%0 : int = prim::Constant[value=1]()
%1 : int[] = prim::Constant[value=[0, -1]]()
%2 : int = prim::Constant[value=0]()
%3 : int = prim::Constant[value=6]()
%4 : int = prim::Constant[value=2]()
%5 : int[] = prim::Constant[value=[3]]()
%a : Dynamic = aten::randn(%5, %3, %2, %1)
%b : Dynamic = aten::add(%a, %4, %0)
return (%b);
}

View File

@@ -1,6 +1,6 @@
graph(%input_tensor : Dynamic) {
%1 : int = prim::Constant[value=8]()
%2 : int = prim::Constant[value=1]()
%3 : Dynamic = aten::add(%input_tensor, %1, %2)
%1 : int = prim::Constant[value=1]()
%2 : int = prim::Constant[value=8]()
%3 : Dynamic = aten::add(%input_tensor, %2, %1)
return (%3);
}

View File

@@ -3,23 +3,21 @@ graph(%mat : Dynamic
%mat2 : Dynamic
%alpha : Dynamic
%beta : Dynamic) {
%5 : float = prim::Constant[value=2]()
%5 : int = prim::Constant[value=1]()
%6 : float = prim::Constant[value=4.2]()
%7 : Dynamic = aten::mm(%mat1, %mat2)
%8 : int = prim::Constant[value=1]()
%9 : Dynamic = aten::add(%mat, %7, %8)
%10 : Dynamic = aten::mm(%mat1, %mat2)
%11 : int = prim::Constant[value=1]()
%12 : Dynamic = aten::add(%mat, %10, %11)
%c : Dynamic = aten::addmm(%mat, %mat1, %mat2, %5, %6)
%14 : int = prim::TensorToNum(%alpha)
%15 : int = prim::TensorToNum(%beta)
%d : Dynamic = aten::addmm(%mat, %mat1, %mat2, %15, %14)
%17 : int = prim::Constant[value=1]()
%18 : Dynamic = aten::add(%9, %12, %17)
%19 : int = prim::Constant[value=1]()
%20 : Dynamic = aten::add(%18, %c, %19)
%21 : int = prim::Constant[value=1]()
%22 : Dynamic = aten::add(%20, %d, %21)
return (%22);
%7 : float = prim::Constant[value=2]()
%8 : Dynamic = aten::mm(%mat1, %mat2)
%9 : int = prim::Constant[value=1]()
%10 : Dynamic = aten::add(%mat, %8, %9)
%11 : Dynamic = aten::mm(%mat1, %mat2)
%12 : int = prim::Constant[value=1]()
%13 : Dynamic = aten::add(%mat, %11, %12)
%c : Dynamic = aten::addmm(%mat, %mat1, %mat2, %7, %6)
%15 : int = prim::TensorToNum(%alpha)
%16 : int = prim::TensorToNum(%beta)
%d : Dynamic = aten::addmm(%mat, %mat1, %mat2, %16, %15)
%18 : Dynamic = aten::add(%10, %13, %5)
%19 : Dynamic = aten::add(%18, %c, %5)
%20 : Dynamic = aten::add(%19, %d, %5)
return (%20);
}

View File

@@ -6,27 +6,26 @@ ModelProto {
GraphProto {
name: "torch-jit-export"
inputs: [{name: "y.1", type:Tensor dims: 3 4 1}]
outputs: [{name: "5", type:Tensor dims: 3 4 4}]
outputs: [{name: "6", type:Tensor dims: 3 4 4}]
initializers: []
nodes: [
Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Constant", inputs: [], outputs: [2], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Constant", inputs: [], outputs: [4], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Loop", inputs: [3,4,y.1], outputs: [5], attributes: [{ name: 'body', type: graph, value:
Node {type: "Constant", inputs: [], outputs: [5], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Loop", inputs: [3,2,y.1], outputs: [6], attributes: [{ name: 'body', type: graph, value:
GraphProto {
name: "torch-jit-export1"
inputs: [{name: "i", type:Tensor dims: },{name: "cond", type:Tensor dims: },{name: "8", type:Tensor dims: }]
outputs: [{name: "15", type:Tensor dims: },{name: "16", type:Tensor dims: }]
inputs: [{name: "i", type:Tensor dims: },{name: "cond", type:Tensor dims: },{name: "9", type:Tensor dims: }]
outputs: [{name: "2", type:Tensor dims: },{name: "15", type:Tensor dims: }]
initializers: []
nodes: [
Node {type: "Unsqueeze", inputs: [2], outputs: [9], attributes: [{ name: 'axes', type: ints, values: [0]}]},
Node {type: "Unsqueeze", inputs: [1], outputs: [10], attributes: [{ name: 'axes', type: ints, values: [0]}]},
Node {type: "Unsqueeze", inputs: [i], outputs: [11], attributes: [{ name: 'axes', type: ints, values: [0]}]},
Node {type: "Concat", inputs: [9,10,11], outputs: [12], attributes: [{ name: 'axis', type: int, value: 0}]},
Node {type: "Constant", inputs: [], outputs: [13], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "ATen", inputs: [y.1,12,13], outputs: [16], attributes: [{ name: 'operator', type: string, value: 'expand'}]},
Node {type: "Constant", inputs: [], outputs: [15], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}
Node {type: "Unsqueeze", inputs: [4], outputs: [10], attributes: [{ name: 'axes', type: ints, values: [0]}]},
Node {type: "Unsqueeze", inputs: [5], outputs: [11], attributes: [{ name: 'axes', type: ints, values: [0]}]},
Node {type: "Unsqueeze", inputs: [i], outputs: [12], attributes: [{ name: 'axes', type: ints, values: [0]}]},
Node {type: "Concat", inputs: [10,11,12], outputs: [13], attributes: [{ name: 'axis', type: int, value: 0}]},
Node {type: "ATen", inputs: [y.1,13,1], outputs: [15], attributes: [{ name: 'operator', type: string, value: 'expand'}]}
]
}

View File

@@ -5,13 +5,13 @@ graph(%0 : Float(*, *)
%4 : Float(*, *)
%5 : Float(*)
%6 : Float(*)) {
%7 : Float(*, *) = aten::t(%3)
%8 : Float(*, *) = aten::mm(%0, %7)
%9 : int = prim::Constant[value=1]()
%10 : Float(*, *) = aten::add(%5, %8, %9)
%7 : int = prim::Constant[value=1]()
%8 : Float(*, *) = aten::t(%3)
%9 : Float(*, *) = aten::mm(%0, %8)
%10 : Float(*, *) = aten::add(%5, %9, %7)
%11 : Float(*, *) = aten::t(%4)
%12 : Float(*, *) = aten::mm(%1, %11)
%13 : Float(*, *) = aten::add(%6, %12, %9)
%13 : Float(*, *) = aten::add(%6, %12, %7)
%14 : Dynamic[] = prim::ListConstruct(%10, %13)
%15 : Dynamic[] = aten::broadcast_tensors(%14)
%16 : Dynamic, %17 : Dynamic = prim::ListUnpack(%15)

View File

@@ -5,13 +5,13 @@ graph(%0 : Float(*, *)
%4 : Float(*, *)
%5 : Float(*)
%6 : Float(*)) {
%7 : Float(*, *) = aten::t(%3)
%8 : Float(*, *) = aten::mm(%0, %7)
%9 : int = prim::Constant[value=1]()
%10 : Float(*, *) = aten::add(%5, %8, %9)
%7 : int = prim::Constant[value=1]()
%8 : Float(*, *) = aten::t(%3)
%9 : Float(*, *) = aten::mm(%0, %8)
%10 : Float(*, *) = aten::add(%5, %9, %7)
%11 : Float(*, *) = aten::t(%4)
%12 : Float(*, *) = aten::mm(%1, %11)
%13 : Float(*, *) = aten::add(%6, %12, %9)
%13 : Float(*, *) = aten::add(%6, %12, %7)
%14 : Dynamic[] = prim::ListConstruct(%10, %13)
%15 : Dynamic[] = aten::broadcast_tensors(%14)
%16 : Dynamic, %17 : Dynamic = prim::ListUnpack(%15)

View File

@@ -1,7 +1,6 @@
graph(%x : Dynamic) {
%1 : int = prim::Constant[value=1]()
%2 : Dynamic = ^python_fn()(%x)
%3 : int = prim::Constant[value=1]()
%4 : Dynamic = aten::add(%2, %1, %3)
return (%4);
%3 : Dynamic = aten::add(%2, %1, %1)
return (%3);
}

View File

@@ -1,7 +1,6 @@
graph(%x : Dynamic) {
%2 : int = prim::Constant[value=1]()
%1 : Dynamic = ^<python_value>()(%x)
%3 : int = prim::Constant[value=1]()
%4 : Dynamic = aten::add(%1, %2, %3)
%4 : Dynamic = aten::add(%1, %2, %2)
return (%4);
}

View File

@@ -1,7 +1,6 @@
graph(%x : Dynamic) {
%1 : int = prim::Constant[value=1]()
%2 : Dynamic = aten::neg(%x)
%3 : int = prim::Constant[value=1]()
%4 : Dynamic = aten::add(%2, %1, %3)
return (%4);
%3 : Dynamic = aten::add(%2, %1, %1)
return (%3);
}

View File

@@ -1,14 +1,13 @@
graph(%x : Dynamic) {
%1 : int = prim::Constant[value=1]()
%2 : int = prim::Constant[value=3]()
%3 : int = prim::Constant[value=4]()
%4 : int[] = prim::ListConstruct(%3, %2)
%5 : int = prim::Constant[value=6]()
%6 : int = prim::Constant[value=0]()
%7 : int[] = prim::Constant[value=[0, -1]]()
%8 : Dynamic = aten::zeros(%4, %5, %6, %7)
%1 : int = prim::Constant[value=3]()
%2 : int = prim::Constant[value=4]()
%3 : int = prim::Constant[value=6]()
%4 : int = prim::Constant[value=0]()
%5 : int[] = prim::Constant[value=[0, -1]]()
%6 : int = prim::Constant[value=1]()
%7 : int[] = prim::ListConstruct(%2, %1)
%8 : Dynamic = aten::zeros(%7, %3, %4, %5)
%9 : Dynamic = aten::mm(%x, %8)
%10 : int = prim::Constant[value=1]()
%11 : Dynamic = aten::add(%9, %1, %10)
return (%11);
%10 : Dynamic = aten::add(%9, %6, %6)
return (%10);
}

View File

@@ -1,7 +1,6 @@
graph(%x : Dynamic) {
%2 : int = prim::Constant[value=1]()
%1 : Double(3, 4) = aten::neg(%x)
%3 : int = prim::Constant[value=1]()
%4 : Dynamic = aten::add(%1, %2, %3)
%4 : Dynamic = aten::add(%1, %2, %2)
return (%4);
}

View File

@@ -1,14 +1,13 @@
graph(%x : Dynamic) {
%9 : int = prim::Constant[value=1]()
%1 : int = prim::Constant[value=4]()
%2 : int = prim::Constant[value=3]()
%3 : int[] = prim::ListConstruct(%1, %2)
%4 : int = prim::Constant[value=7]()
%5 : int = prim::Constant[value=0]()
%6 : int[] = prim::Constant[value=[0, -1]]()
%5 : int = prim::Constant[value=0]()
%4 : int = prim::Constant[value=7]()
%2 : int = prim::Constant[value=3]()
%1 : int = prim::Constant[value=4]()
%9 : int = prim::Constant[value=1]()
%3 : int[] = prim::ListConstruct(%1, %2)
%7 : Double(4, 3) = aten::zeros(%3, %4, %5, %6)
%8 : Double(3, 3) = aten::mm(%x, %7)
%10 : int = prim::Constant[value=1]()
%11 : Dynamic = aten::add(%8, %9, %10)
%11 : Dynamic = aten::add(%8, %9, %9)
return (%11);
}

View File

@@ -0,0 +1,33 @@
graph(%cond : Dynamic) {
%1 : int[] = prim::Constant[value=[1]]()
%2 : int[] = prim::Constant[value=[0]]()
%3 : int = prim::Constant[value=3]()
%4 : Float(2) = prim::Constant[value= 4 4 [ CPUFloatType{2} ]]()
%5 : Float(2) = prim::Constant[value= 1 1 [ CPUFloatType{2} ]]()
%c.1 : int = prim::Constant[value=0]()
%a : int = prim::Constant[value=1]()
%d : string = prim::Constant[value="abc"]()
%e : string = prim::Constant[value="bcd"]()
%10 : int = prim::Constant[value=6]()
%11 : int[] = prim::Constant[value=[0, -1]]()
%12 : bool = prim::TensorToBool(%cond)
%c : int, %y : Dynamic = prim::If(%12)
block0() {
-> (%3, %4)
}
block1() {
%y.2 : Dynamic = aten::rand(%2, %10, %c.1, %11)
%16 : bool = prim::TensorToBool(%cond)
%y.4 : Dynamic = prim::If(%16)
block0() {
%y.3 : Dynamic = aten::rand(%1, %10, %c.1, %11)
-> (%y.3)
}
block1() {
-> (%y.2)
}
= prim::Print(%d, %e, %d, %5, %y.4, %5)
-> (%c.1, %y.4)
}
return (%a, %3, %c, %5);
}

View File

@@ -1,12 +1,10 @@
graph(%a : Dynamic) {
%1 : Long() = prim::Constant[value={3}]()
%1 : Long() = prim::Constant[value={7}]()
%2 : Long() = prim::Constant[value={1}]()
%3 : Long() = prim::Constant[value={7}]()
%4 : Long() = aten::add(%3, %2)
%b : Long() = aten::add(%4, %1)
%6 : Long() = prim::Constant[value={1}]()
%c.1 : Dynamic = aten::add(%a, %b, %6)
%8 : Long() = prim::Constant[value={1}]()
%c : Dynamic = aten::add(%c.1, %b, %8)
%3 : Long() = prim::Constant[value={3}]()
%4 : Long() = aten::add(%1, %2)
%b : Long() = aten::add(%4, %3)
%c.1 : Dynamic = aten::add(%a, %b, %2)
%c : Dynamic = aten::add(%c.1, %b, %2)
return (%c);
}

View File

@@ -11,25 +11,23 @@ ModelProto {
nodes: [
Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Constant", inputs: [], outputs: [2], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Gather", inputs: [x,2], outputs: [3], attributes: [{ name: 'axis', type: int, value: 0}]},
Node {type: "Shape", inputs: [x], outputs: [4], attributes: []},
Node {type: "Gather", inputs: [4,1], outputs: [5], attributes: [{ name: 'axis', type: int, value: 0}]},
Node {type: "Constant", inputs: [], outputs: [6], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Loop", inputs: [5,6,3], outputs: [7], attributes: [{ name: 'body', type: graph, value:
Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Gather", inputs: [x,2], outputs: [4], attributes: [{ name: 'axis', type: int, value: 0}]},
Node {type: "Shape", inputs: [x], outputs: [5], attributes: []},
Node {type: "Gather", inputs: [5,3], outputs: [6], attributes: [{ name: 'axis', type: int, value: 0}]},
Node {type: "Loop", inputs: [6,1,4], outputs: [7], attributes: [{ name: 'body', type: graph, value:
GraphProto {
name: "torch-jit-export1"
inputs: [{name: "i", type:Tensor dims: },{name: "cond", type:Tensor dims: },{name: "10", type:Tensor dims: }]
outputs: [{name: "18", type:Tensor dims: },{name: "17", type:Tensor dims: }]
outputs: [{name: "1", type:Tensor dims: },{name: "16", type:Tensor dims: }]
initializers: []
nodes: [
Node {type: "Constant", inputs: [], outputs: [11], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Unsqueeze", inputs: [2], outputs: [12], attributes: [{ name: 'axes', type: ints, values: [0]}]},
Node {type: "Unsqueeze", inputs: [i], outputs: [13], attributes: [{ name: 'axes', type: ints, values: [0]}]},
Node {type: "Unsqueeze", inputs: [11], outputs: [14], attributes: [{ name: 'axes', type: ints, values: [0]}]},
Node {type: "DynamicSlice", inputs: [x,12,13,14], outputs: [15], attributes: []},
Node {type: "ReduceSum", inputs: [15], outputs: [16], attributes: [{ name: 'axes', type: ints, values: [0]},{ name: 'keepdims', type: int, value: 0}]},
Node {type: "Add", inputs: [10,16], outputs: [17], attributes: []},
Node {type: "Constant", inputs: [], outputs: [18], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}
Node {type: "Unsqueeze", inputs: [2], outputs: [11], attributes: [{ name: 'axes', type: ints, values: [0]}]},
Node {type: "Unsqueeze", inputs: [i], outputs: [12], attributes: [{ name: 'axes', type: ints, values: [0]}]},
Node {type: "Unsqueeze", inputs: [2], outputs: [13], attributes: [{ name: 'axes', type: ints, values: [0]}]},
Node {type: "DynamicSlice", inputs: [x,11,12,13], outputs: [14], attributes: []},
Node {type: "ReduceSum", inputs: [14], outputs: [15], attributes: [{ name: 'axes', type: ints, values: [0]},{ name: 'keepdims', type: int, value: 0}]},
Node {type: "Add", inputs: [10,15], outputs: [16], attributes: []}
]
}

View File

@@ -1,6 +1,7 @@
graph(%x : Double(*, *)) {
%1 : bool = prim::Constant[value=1]()
%c : Dynamic[] = prim::If(%1)
%1 : int = prim::Constant[value=0]()
%2 : bool = prim::Constant[value=1]()
%c : Dynamic[] = prim::If(%2)
block0() {
%c.1 : Dynamic[] = prim::ListConstruct(%x, %x)
-> (%c.1)
@@ -9,7 +10,6 @@ graph(%x : Double(*, *)) {
%c.2 : Dynamic[] = prim::ListConstruct(%x, %x, %x)
-> (%c.2)
}
%5 : int = prim::Constant[value=0]()
%6 : Dynamic = aten::cat(%c, %5)
%6 : Dynamic = aten::cat(%c, %1)
return (%6);
}

View File

@@ -1,34 +1,32 @@
graph(%t : Dynamic) {
%c1.2 : int = prim::Constant[value=0]()
%1 : bool = prim::Constant[value=1]()
%2 : bool = prim::Constant[value=0]()
%c1.1 : int = prim::Constant[value=1]()
%3 : bool = prim::Constant[value=0]()
%4 : bool = prim::If(%3)
%c1.2 : int = prim::Constant[value=0]()
%5 : bool = prim::If(%2)
block0() {
%5 : int = prim::Constant[value=0]()
%6 : Dynamic = aten::select(%t, %5, %c1.1)
%6 : Dynamic = aten::select(%t, %c1.2, %c1.1)
%7 : bool = prim::TensorToBool(%6)
-> (%7)
}
block1() {
-> (%3)
-> (%2)
}
%8 : bool = prim::If(%4)
%8 : bool = prim::If(%5)
block0() {
-> (%4)
-> (%5)
}
block1() {
%9 : bool = prim::Constant[value=1]()
%10 : bool = prim::If(%9)
%9 : bool = prim::If(%1)
block0() {
-> (%9)
-> (%1)
}
block1() {
%11 : int = prim::Constant[value=0]()
%12 : Dynamic = aten::select(%t, %11, %c1.1)
%13 : bool = prim::TensorToBool(%12)
-> (%13)
%10 : Dynamic = aten::select(%t, %c1.2, %c1.1)
%11 : bool = prim::TensorToBool(%10)
-> (%11)
}
-> (%10)
-> (%9)
}
%c1 : int = prim::If(%8)
block0() {

View File

@@ -1,31 +1,29 @@
graph(%x : Dynamic) {
%1 : int = prim::Constant[value=1]()
%1 : bool = prim::Constant[value=1]()
%y.1 : int = prim::Constant[value=0]()
%3 : int = prim::TensorToNum(%x)
%4 : bool = prim::Constant[value=1]()
%3 : int = prim::Constant[value=1]()
%4 : int = prim::TensorToNum(%x)
%5 : int = prim::Constant[value=8]()
%6 : int = aten::floordiv(%3, %5)
%6 : int = aten::floordiv(%4, %5)
%7 : int = prim::Constant[value=8]()
%8 : int = aten::mul(%6, %7)
%9 : int = aten::sub(%3, %8)
%y.3 : int = prim::Loop(%6, %4, %y.1)
%9 : int = aten::sub(%4, %8)
%y.3 : int = prim::Loop(%6, %1, %y.1)
block0(%i.1 : int, %12 : int) {
%y.12 : int = aten::add(%12, %1)
%y.5 : int = aten::add(%y.12, %1)
%y.6 : int = aten::add(%y.5, %1)
%y.7 : int = aten::add(%y.6, %1)
%y.8 : int = aten::add(%y.7, %1)
%y.9 : int = aten::add(%y.8, %1)
%y.10 : int = aten::add(%y.9, %1)
%y.11 : int = aten::add(%y.10, %1)
%21 : bool = prim::Constant[value=1]()
-> (%21, %y.11)
%y.12 : int = aten::add(%12, %3)
%y.5 : int = aten::add(%y.12, %3)
%y.6 : int = aten::add(%y.5, %3)
%y.7 : int = aten::add(%y.6, %3)
%y.8 : int = aten::add(%y.7, %3)
%y.9 : int = aten::add(%y.8, %3)
%y.10 : int = aten::add(%y.9, %3)
%y.11 : int = aten::add(%y.10, %3)
-> (%1, %y.11)
}
%y : int = prim::Loop(%9, %4, %y.3)
block0(%i : int, %24 : int) {
%y.4 : int = aten::add(%24, %1)
%26 : bool = prim::Constant[value=1]()
-> (%26, %y.4)
%y : int = prim::Loop(%9, %1, %y.3)
block0(%i : int, %23 : int) {
%y.4 : int = aten::add(%23, %3)
-> (%1, %y.4)
}
return (%y);
}

View File

@@ -1,14 +1,14 @@
graph(%x : Dynamic) {
%1 : bool = prim::Constant[value=1]()
%y.1 : int = prim::Constant[value=0]()
%2 : int = prim::TensorToNum(%x)
%3 : bool = prim::Constant[value=1]()
%3 : int = prim::TensorToNum(%x)
%4 : int = prim::Constant[value=0]()
%5 : int = prim::Constant[value=8]()
%6 : int = aten::floordiv(%2, %5)
%6 : int = aten::floordiv(%3, %5)
%7 : int = prim::Constant[value=8]()
%8 : int = aten::mul(%6, %7)
%9 : int = aten::sub(%2, %8)
%10 : Dynamic, %y.3 : int = prim::Loop(%6, %3, %4, %y.1)
%9 : int = aten::sub(%3, %8)
%10 : Dynamic, %y.3 : int = prim::Loop(%6, %1, %4, %y.1)
block0(%i.1 : int, %13 : int, %14 : int) {
%y.12 : int = aten::add(%14, %13)
%16 : int = prim::Constant[value=1]()
@@ -32,18 +32,16 @@ graph(%x : Dynamic) {
%34 : int = prim::Constant[value=1]()
%35 : int = aten::add(%32, %34)
%y.11 : int = aten::add(%y.10, %35)
%37 : bool = prim::Constant[value=1]()
%38 : int = prim::Constant[value=1]()
%39 : int = aten::add(%35, %38)
-> (%37, %39, %y.11)
%37 : int = prim::Constant[value=1]()
%38 : int = aten::add(%35, %37)
-> (%1, %38, %y.11)
}
%40 : Dynamic, %y : int = prim::Loop(%9, %3, %10, %y.3)
block0(%i : int, %43 : int, %44 : int) {
%y.4 : int = aten::add(%44, %43)
%46 : bool = prim::Constant[value=1]()
%47 : int = prim::Constant[value=1]()
%48 : int = aten::add(%43, %47)
-> (%46, %48, %y.4)
%39 : Dynamic, %y : int = prim::Loop(%9, %1, %10, %y.3)
block0(%i : int, %42 : int, %43 : int) {
%y.4 : int = aten::add(%43, %42)
%45 : int = prim::Constant[value=1]()
%46 : int = aten::add(%42, %45)
-> (%1, %46, %y.4)
}
return (%y);
}

View File

@@ -1,15 +1,15 @@
graph() {
%0 : int = prim::Constant[value=1]()
%y.1 : int = prim::Constant[value=0]()
%y.11 : int = aten::add(%y.1, %0)
%y.2 : int = aten::add(%y.11, %0)
%y.3 : int = aten::add(%y.2, %0)
%y.4 : int = aten::add(%y.3, %0)
%y.5 : int = aten::add(%y.4, %0)
%y.6 : int = aten::add(%y.5, %0)
%y.7 : int = aten::add(%y.6, %0)
%y.8 : int = aten::add(%y.7, %0)
%y.9 : int = aten::add(%y.8, %0)
%y.10 : int = aten::add(%y.9, %0)
%1 : int = prim::Constant[value=1]()
%y.11 : int = aten::add(%y.1, %1)
%y.2 : int = aten::add(%y.11, %1)
%y.3 : int = aten::add(%y.2, %1)
%y.4 : int = aten::add(%y.3, %1)
%y.5 : int = aten::add(%y.4, %1)
%y.6 : int = aten::add(%y.5, %1)
%y.7 : int = aten::add(%y.6, %1)
%y.8 : int = aten::add(%y.7, %1)
%y.9 : int = aten::add(%y.8, %1)
%y.10 : int = aten::add(%y.9, %1)
return (%y.10);
}

View File

@@ -1,56 +1,52 @@
graph(%x : Dynamic) {
%1 : int = prim::Constant[value=10]()
%1 : bool = prim::Constant[value=1]()
%y.1 : int = prim::Constant[value=0]()
%3 : bool = prim::Constant[value=1]()
%y : int = prim::Loop(%1, %3, %y.1)
%3 : int = prim::Constant[value=10]()
%y : int = prim::Loop(%3, %1, %y.1)
block0(%i : int, %6 : int) {
%7 : int = prim::TensorToNum(%x)
%8 : bool = prim::Constant[value=1]()
%9 : int = prim::Constant[value=0]()
%10 : int = prim::Constant[value=8]()
%11 : int = aten::floordiv(%7, %10)
%12 : int = prim::Constant[value=8]()
%13 : int = aten::mul(%11, %12)
%14 : int = aten::sub(%7, %13)
%15 : Dynamic, %y.4 : int = prim::Loop(%11, %8, %9, %6)
block0(%j.1 : int, %18 : int, %19 : int) {
%y.13 : int = aten::add(%19, %18)
%21 : int = prim::Constant[value=1]()
%22 : int = aten::add(%18, %21)
%y.6 : int = aten::add(%y.13, %22)
%24 : int = prim::Constant[value=1]()
%25 : int = aten::add(%22, %24)
%y.7 : int = aten::add(%y.6, %25)
%27 : int = prim::Constant[value=1]()
%28 : int = aten::add(%25, %27)
%y.8 : int = aten::add(%y.7, %28)
%30 : int = prim::Constant[value=1]()
%31 : int = aten::add(%28, %30)
%y.9 : int = aten::add(%y.8, %31)
%33 : int = prim::Constant[value=1]()
%34 : int = aten::add(%31, %33)
%y.10 : int = aten::add(%y.9, %34)
%36 : int = prim::Constant[value=1]()
%37 : int = aten::add(%34, %36)
%y.11 : int = aten::add(%y.10, %37)
%39 : int = prim::Constant[value=1]()
%40 : int = aten::add(%37, %39)
%y.12 : int = aten::add(%y.11, %40)
%42 : bool = prim::Constant[value=1]()
%43 : int = prim::Constant[value=1]()
%44 : int = aten::add(%40, %43)
-> (%42, %44, %y.12)
%8 : int = prim::Constant[value=0]()
%9 : int = prim::Constant[value=8]()
%10 : int = aten::floordiv(%7, %9)
%11 : int = prim::Constant[value=8]()
%12 : int = aten::mul(%10, %11)
%13 : int = aten::sub(%7, %12)
%14 : Dynamic, %y.4 : int = prim::Loop(%10, %1, %8, %6)
block0(%j.1 : int, %17 : int, %18 : int) {
%y.13 : int = aten::add(%18, %17)
%20 : int = prim::Constant[value=1]()
%21 : int = aten::add(%17, %20)
%y.6 : int = aten::add(%y.13, %21)
%23 : int = prim::Constant[value=1]()
%24 : int = aten::add(%21, %23)
%y.7 : int = aten::add(%y.6, %24)
%26 : int = prim::Constant[value=1]()
%27 : int = aten::add(%24, %26)
%y.8 : int = aten::add(%y.7, %27)
%29 : int = prim::Constant[value=1]()
%30 : int = aten::add(%27, %29)
%y.9 : int = aten::add(%y.8, %30)
%32 : int = prim::Constant[value=1]()
%33 : int = aten::add(%30, %32)
%y.10 : int = aten::add(%y.9, %33)
%35 : int = prim::Constant[value=1]()
%36 : int = aten::add(%33, %35)
%y.11 : int = aten::add(%y.10, %36)
%38 : int = prim::Constant[value=1]()
%39 : int = aten::add(%36, %38)
%y.12 : int = aten::add(%y.11, %39)
%41 : int = prim::Constant[value=1]()
%42 : int = aten::add(%39, %41)
-> (%1, %42, %y.12)
}
%45 : Dynamic, %y.3 : int = prim::Loop(%14, %8, %15, %y.4)
block0(%j : int, %48 : int, %49 : int) {
%y.5 : int = aten::add(%49, %48)
%51 : bool = prim::Constant[value=1]()
%52 : int = prim::Constant[value=1]()
%53 : int = aten::add(%48, %52)
-> (%51, %53, %y.5)
%43 : Dynamic, %y.3 : int = prim::Loop(%13, %1, %14, %y.4)
block0(%j : int, %46 : int, %47 : int) {
%y.5 : int = aten::add(%47, %46)
%49 : int = prim::Constant[value=1]()
%50 : int = aten::add(%46, %49)
-> (%1, %50, %y.5)
}
%54 : bool = prim::Constant[value=1]()
-> (%54, %y.3)
-> (%1, %y.3)
}
return (%y);
}

View File

@@ -1,6 +1,6 @@
graph(%x : Dynamic) {
%1 : float = prim::Constant[value=3.1]()
%2 : float = prim::Constant[value=1.1]()
%3 : float = aten::add(%2, %1)
%1 : float = prim::Constant[value=1.1]()
%2 : float = prim::Constant[value=3.1]()
%3 : float = aten::add(%1, %2)
return (%3);
}

View File

@@ -1,6 +1,6 @@
graph(%x : Dynamic) {
%1 : int = prim::Constant[value=8]()
%2 : int = prim::Constant[value=7]()
%3 : int = aten::add(%2, %1)
%1 : int = prim::Constant[value=7]()
%2 : int = prim::Constant[value=8]()
%3 : int = aten::add(%1, %2)
return (%3);
}

View File

@@ -1,6 +1,6 @@
graph(%x : Dynamic) {
%1 : int = prim::Constant[value=7]()
%2 : int = prim::Constant[value=1]()
%3 : Dynamic = aten::add(%x, %1, %2)
%1 : int = prim::Constant[value=1]()
%2 : int = prim::Constant[value=7]()
%3 : Dynamic = aten::add(%x, %2, %1)
return (%3);
}

View File

@@ -11,15 +11,14 @@ ModelProto {
nodes: [
Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Constant", inputs: [], outputs: [2], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Loop", inputs: [1,2,x.1], outputs: [3], attributes: [{ name: 'body', type: graph, value:
Node {type: "Loop", inputs: [2,1,x.1], outputs: [3], attributes: [{ name: 'body', type: graph, value:
GraphProto {
name: "torch-jit-export1"
inputs: [{name: "_", type:Tensor dims: },{name: "cond", type:Tensor dims: },{name: "6", type:Tensor dims: }]
outputs: [{name: "8", type:Tensor dims: },{name: "7", type:Tensor dims: }]
outputs: [{name: "1", type:Tensor dims: },{name: "7", type:Tensor dims: }]
initializers: []
nodes: [
Node {type: "Add", inputs: [6,6], outputs: [7], attributes: []},
Node {type: "Constant", inputs: [], outputs: [8], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}
Node {type: "Add", inputs: [6,6], outputs: [7], attributes: []}
]
}

View File

@@ -12,9 +12,9 @@ ModelProto {
Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Constant", inputs: [], outputs: [2], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Shape", inputs: [x], outputs: [3], attributes: []},
Node {type: "Gather", inputs: [3,2], outputs: [4], attributes: [{ name: 'axis', type: int, value: 0}]},
Node {type: "Gather", inputs: [3,1], outputs: [4], attributes: [{ name: 'axis', type: int, value: 0}]},
Node {type: "Cast", inputs: [4], outputs: [5], attributes: [{ name: 'to', type: int, value: 1}]},
Node {type: "Cast", inputs: [1], outputs: [6], attributes: [{ name: 'to', type: int, value: 1}]},
Node {type: "Cast", inputs: [2], outputs: [6], attributes: [{ name: 'to', type: int, value: 1}]},
Node {type: "Div", inputs: [5,6], outputs: [7], attributes: []},
Node {type: "Add", inputs: [x,7], outputs: [8], attributes: []}
]

View File

@@ -11,14 +11,14 @@ ModelProto {
nodes: [
Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Constant", inputs: [], outputs: [2], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "size", inputs: [x,2], outputs: [3], attributes: []},
Node {type: "Constant", inputs: [], outputs: [4], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "_cast_Float", inputs: [3,4], outputs: [5], attributes: []},
Node {type: "Constant", inputs: [], outputs: [6], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "_cast_Float", inputs: [1,6], outputs: [7], attributes: []},
Node {type: "div", inputs: [5,7], outputs: [z], attributes: []},
Node {type: "Constant", inputs: [], outputs: [9], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "add", inputs: [x,z,9], outputs: [10], attributes: []}
Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "size", inputs: [x,2], outputs: [4], attributes: []},
Node {type: "Constant", inputs: [], outputs: [5], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "_cast_Float", inputs: [4,5], outputs: [6], attributes: []},
Node {type: "Constant", inputs: [], outputs: [7], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "_cast_Float", inputs: [3,7], outputs: [8], attributes: []},
Node {type: "div", inputs: [6,8], outputs: [z], attributes: []},
Node {type: "add", inputs: [x,z,1], outputs: [10], attributes: []}
]
}
opset_import: [OperatorSetIdProto { domain: }],

View File

@@ -1,7 +1,7 @@
graph(%a : Dynamic) {
%2 : int = prim::Constant[value=2]()
%1 : string = prim::Constant[value="a\n\tb\n"]()
%3 : string = prim::Constant[value="aa"]()
%1 : string = prim::Constant[value="a\n\tb\n"]()
%2 : int = prim::Constant[value=2]()
= prim::Print(%a, %1, %2, %3)
return (%a);
}

View File

@@ -1,7 +1,7 @@
graph(%x : Dynamic) {
%1 : int = prim::Constant[value=4]()
%2 : int[] = prim::ListConstruct(%1)
%3 : bool = prim::Constant[value=0]()
%4 : Dynamic = aten::sum(%x, %2, %3)
%1 : bool = prim::Constant[value=0]()
%2 : int = prim::Constant[value=4]()
%3 : int[] = prim::ListConstruct(%2)
%4 : Dynamic = aten::sum(%x, %3, %1)
return (%4);
}

View File

@@ -1,7 +1,7 @@
graph(%x : Double(*, *, *, *, *)) {
%1 : int = prim::Constant[value=4]()
%2 : int[] = prim::ListConstruct(%1)
%3 : bool = prim::Constant[value=0]()
%4 : Dynamic = aten::sum(%x, %2, %3)
%1 : bool = prim::Constant[value=0]()
%2 : int = prim::Constant[value=4]()
%3 : int[] = prim::ListConstruct(%2)
%4 : Dynamic = aten::sum(%x, %3, %1)
return (%4);
}

View File

@@ -1,8 +1,8 @@
graph(%x : Float(*, *)
%z : Float()) {
%2 : int = prim::TensorToNum(%z)
%3 : int = prim::Constant[value=1]()
%y : Float(*, *) = aten::add(%x, %2, %3)
%2 : int = prim::Constant[value=1]()
%3 : int = prim::TensorToNum(%z)
%y : Float(*, *) = aten::add(%x, %3, %2)
%5 : Float(*, *) = aten::mul(%x, %y)
return (%5);
}

View File

@@ -1,7 +1,7 @@
graph() {
%0 : int = prim::Constant[value=1]()
%1 : float = prim::Constant[value=5]()
%b : int = prim::FloatToInt(%1)
%3 : int = aten::add(%b, %0)
%0 : float = prim::Constant[value=5]()
%1 : int = prim::Constant[value=1]()
%b : int = prim::FloatToInt(%0)
%3 : int = aten::add(%b, %1)
return (%3);
}

View File

@@ -1,7 +1,7 @@
graph() {
%0 : float = prim::Constant[value=1]()
%1 : int = prim::Constant[value=2]()
%b : float = prim::IntToFloat(%1)
%3 : float = aten::add(%b, %0)
%0 : int = prim::Constant[value=2]()
%1 : float = prim::Constant[value=1]()
%b : float = prim::IntToFloat(%0)
%3 : float = aten::add(%b, %1)
return (%3);
}

View File

@@ -3075,6 +3075,33 @@ a")
        y2 = torch.sum(x, dim=0)
        self.assertEqual(y, y2)

    def test_constant_pooling(self):
        def func(cond):
            a = 1
            b = 4
            c = 0
            d = "abc"
            e = "bcd"
            f = "abc"
            x = torch.ones([2])
            y = x * 4
            z = torch.ones([2])
            if bool(cond):
                c = b - a
            else:
                y = torch.rand(0)
            if bool(cond):
                y = torch.rand(1)
            print(d, e, f, x, y, z)
            b = b - a
            return a, b, c, x

        self.checkScript(func, torch.tensor([1]))
        graph = torch.jit.script(func).graph
        self.run_pass('constant_propagation', graph)
        self.run_pass('constant_pooling', graph)
        self.assertExpectedGraph(graph)

    def test_literal(self):
        def func1(a, b):
            c = a, b

View File

@@ -146,6 +146,7 @@ set(TORCH_SRCS
${TORCH_SRC_DIR}/csrc/jit/import.cpp
${TORCH_SRC_DIR}/csrc/jit/interpreter.cpp
${TORCH_SRC_DIR}/csrc/jit/constants.cpp
${TORCH_SRC_DIR}/csrc/jit/node_hashing.cpp
${TORCH_SRC_DIR}/csrc/jit/ir.cpp
${TORCH_SRC_DIR}/csrc/jit/operator.cpp
@@ -153,6 +154,7 @@ set(TORCH_SRCS
${TORCH_SRC_DIR}/csrc/jit/passes/batch_mm.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/canonicalize.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/constant_propagation.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/constant_pooling.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/common_subexpression_elimination.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/create_autodiff_subgraphs.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/inline_autodiff_subgraphs.cpp

View File

@@ -10,6 +10,7 @@
#include "torch/csrc/jit/passes/annotate_effects.h"
#include "torch/csrc/jit/passes/batch_mm.h"
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/csrc/jit/passes/constant_pooling.h"
#include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/erase_number_types.h"
@@ -444,6 +445,7 @@ private:
void runOptimization(std::shared_ptr<Graph>& graph, const ArgumentSpec& spec) {
EliminateDeadCode(graph);
EliminateCommonSubexpression(graph);
ConstantPooling(graph);
UnrollLoops(graph);
PeepholeOptimize(graph);
CheckInplace(graph);

View File

@@ -15,6 +15,7 @@
#include "torch/csrc/jit/passes/erase_number_types.h"
#include "torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h"
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/csrc/jit/passes/constant_pooling.h"
#include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
#include "torch/csrc/jit/passes/peephole.h"
#include "torch/csrc/jit/passes/canonicalize.h"
@@ -88,6 +89,7 @@ void initJITBindings(PyObject *module) {
.def("_jit_pass_cse", [](std::shared_ptr<Graph>& g) {
return EliminateCommonSubexpression(g); // overload resolution
})
.def("_jit_pass_constant_pooling", ConstantPooling)
.def("_jit_pass_peephole", PeepholeOptimize)
.def("_jit_pass_canonicalize", [](const std::shared_ptr<Graph>& g) {
return Canonicalize(g);

View File

@@ -0,0 +1,113 @@
#include "torch/csrc/jit/ir.h"

#include <algorithm>
#include <unordered_map>

#include "torch/csrc/jit/assertions.h"
#include "torch/csrc/jit/interned_strings.h"
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/csrc/jit/node_hashing.h"
#include "torch/csrc/utils/functional.h"
#include "torch/csrc/utils/hash.h"

namespace torch { namespace jit {

namespace {

bool tensorEqual(const at::Tensor& lhs, const at::Tensor& rhs) {
  return &lhs.type() == &rhs.type() && lhs.equal(rhs);
}

bool tensorListEqual(const std::vector<at::Tensor>& lhs, const std::vector<at::Tensor>& rhs) {
  if (lhs.size() != rhs.size()) return false;
  return std::equal(lhs.begin(), lhs.end(), rhs.begin(), tensorEqual);
}

// Check whether two nodes have the same attributes in CSE.
// This function may be too conservative for general use.
// Do NOT support g/gs attributes.
bool attributesEqualCSE(const Node* lhs, const Node* rhs) {
  JIT_ASSERT(lhs != nullptr);
  JIT_ASSERT(rhs != nullptr);
  // One has attributes, the other does not.
  if (lhs->hasAttributes() != rhs->hasAttributes()) return false;
  // Neither has attributes.
  if (!lhs->hasAttributes() && !rhs->hasAttributes()) return true;

  auto lnames = lhs->attributeNames();
  auto rnames = rhs->attributeNames();
  std::sort(lnames.begin(), lnames.end());
  std::sort(rnames.begin(), rnames.end());
  if (lnames != rnames) return false;

  for (auto name : lnames) {
    if (lhs->kindOf(name) != rhs->kindOf(name)) return false;

#define COMPARE_ATTRIBUTEVALUE(type) \
    case AttributeKind::type: \
      { if (lhs->type(name) != rhs->type(name)) return false; } break;

    switch(lhs->kindOf(name)) {
      COMPARE_ATTRIBUTEVALUE(f)
      COMPARE_ATTRIBUTEVALUE(fs)
      COMPARE_ATTRIBUTEVALUE(i)
      COMPARE_ATTRIBUTEVALUE(is)
      COMPARE_ATTRIBUTEVALUE(s)
      COMPARE_ATTRIBUTEVALUE(ss)
      case AttributeKind::t: {
        if (!tensorEqual(lhs->t(name), rhs->t(name))) return false;
        break;
      }
      case AttributeKind::ts: {
        if (!tensorListEqual(lhs->ts(name), rhs->ts(name))) return false;
        break;
      }
      case AttributeKind::g:
      case AttributeKind::gs:
        return false;
    }

#undef COMPARE_ATTRIBUTEVALUE
  }

  return true;
}

} // anonymous namespace

size_t HashNode::operator()(const Node* k) const {
  JIT_ASSERT(k != nullptr);
  return get_hash(k->kind(),
                  fmap(k->outputs(), [](const Value *v) { return v->type()->kind(); }),
                  fmap(k->inputs(), [](const Value *v) { return v->unique(); }));
};

bool EqualNode::operator()(const Node* lhs, const Node* rhs) const {
  if (lhs == nullptr && rhs == nullptr) return true;
  if (lhs == nullptr || rhs == nullptr) return false;
  if (lhs->kind() != rhs->kind()) return false;

  // Check whether the output types are the same.
  auto lhs_outputs = lhs->outputs();
  auto rhs_outputs = rhs->outputs();
  if (lhs_outputs.size() != rhs_outputs.size()) return false;
  for (size_t i = 0; i < lhs_outputs.size(); ++i) {
    if (*lhs_outputs[i]->type() != *rhs_outputs[i]->type())
      return false;
  }

  // Check whether the inputs are the same.
  auto lhs_inputs = lhs->inputs();
  auto rhs_inputs = rhs->inputs();
  if (lhs_inputs.size() != rhs_inputs.size()) return false;
  if (!std::equal(lhs_inputs.begin(), lhs_inputs.end(), rhs_inputs.begin())) return false;

  if (!attributesEqualCSE(lhs, rhs)) return false;

  return true;
};

}}

View File

@@ -0,0 +1,15 @@
#pragma once

#include "torch/csrc/jit/ir.h"

namespace torch { namespace jit {

struct HashNode {
  size_t operator()(const Node* k) const;
};

struct EqualNode {
  bool operator()(const Node* lhs, const Node* rhs) const;
};

}}

View File

@@ -6,118 +6,17 @@
#include "torch/csrc/jit/assertions.h"
#include "torch/csrc/jit/interned_strings.h"
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/csrc/jit/node_hashing.h"
#include "torch/csrc/utils/functional.h"
#include "torch/csrc/utils/hash.h"
namespace torch { namespace jit {
namespace {
bool tensorEqual(const at::Tensor& lhs, const at::Tensor& rhs) {
return &lhs.type() == &rhs.type() && lhs.equal(rhs);
}
bool tensorListEqual(const std::vector<at::Tensor>& lhs, const std::vector<at::Tensor>& rhs) {
if (lhs.size() != rhs.size()) return false;
return std::equal(lhs.begin(), lhs.end(), rhs.begin(), tensorEqual);
}
// Check whether two nodes have the same attributes in CSE.
// This function may be too conservative for general use.
// Do NOT support t/ts/g/gs attributes.
// If t/ts are supported, CONSTANT node comparison may need to consider device.
bool attributesEqualCSE(const Node* lhs, const Node* rhs) {
JIT_ASSERT(lhs != nullptr);
JIT_ASSERT(rhs != nullptr);
// One has attributes, the other does not.
if (lhs->hasAttributes() != rhs->hasAttributes()) return false;
// Neither has attributes.
if (!lhs->hasAttributes() && !rhs->hasAttributes()) return true;
auto lnames = lhs->attributeNames();
auto rnames = rhs->attributeNames();
std::sort(lnames.begin(), lnames.end());
std::sort(rnames.begin(), rnames.end());
if (lnames != rnames) return false;
for (auto name : lnames) {
if (lhs->kindOf(name) != rhs->kindOf(name)) return false;
#define COMPARE_ATTRIBUTEVALUE(type) \
case AttributeKind::type: \
{ if (lhs->type(name) != rhs->type(name)) return false; } break;
switch(lhs->kindOf(name)) {
COMPARE_ATTRIBUTEVALUE(f)
COMPARE_ATTRIBUTEVALUE(fs)
COMPARE_ATTRIBUTEVALUE(i)
COMPARE_ATTRIBUTEVALUE(is)
COMPARE_ATTRIBUTEVALUE(s)
COMPARE_ATTRIBUTEVALUE(ss)
case AttributeKind::t: {
if (!tensorEqual(lhs->t(name), rhs->t(name))) return false;
break;
}
case AttributeKind::ts: {
if (!tensorListEqual(lhs->ts(name), rhs->ts(name))) return false;
break;
}
case AttributeKind::g:
case AttributeKind::gs:
return false;
}
#undef COMPARE_ATTRIBUTEVALUE
}
return true;
}
struct HashNodeCSE {
size_t operator()(const Node* k) const {
JIT_ASSERT(k != nullptr);
return get_hash(k->kind(),
fmap(k->outputs(), [](const Value *v) { return v->type()->kind(); }),
fmap(k->inputs(), [](const Value *v) { return v->unique(); }));
}
};
struct EqualNodeCSE {
bool operator()(const Node* lhs, const Node* rhs) const {
if (lhs == nullptr && rhs == nullptr) return true;
if (lhs == nullptr || rhs == nullptr) return false;
if (lhs->kind() != rhs->kind()) return false;
// Check whether the output types are the same.
auto lhs_outputs = lhs->outputs();
auto rhs_outputs = rhs->outputs();
if (lhs_outputs.size() != rhs_outputs.size()) return false;
for (size_t i = 0; i < lhs_outputs.size(); ++i) {
if (*lhs_outputs[i]->type() != *rhs_outputs[i]->type())
return false;
}
// Check whether the inputs are the same.
auto lhs_inputs = lhs->inputs();
auto rhs_inputs = rhs->inputs();
if (lhs_inputs.size() != rhs_inputs.size()) return false;
if (!std::equal(lhs_inputs.begin(), lhs_inputs.end(), rhs_inputs.begin())) return false;
if (!attributesEqualCSE(lhs, rhs)) return false;
return true;
}
};
} // anonymous namespace
// The function implements common subexpression elimination.
// Since the nodes are visited in topological order, one pass is enough.
void EliminateCommonSubexpression(Block * block,
std::function<Node*(Node*)> parent_lookup_fn) {
std::unordered_set<Node*, HashNodeCSE, EqualNodeCSE> subexprs;
std::unordered_set<Node*, HashNode, EqualNode> subexprs;
for (auto it = block->nodes().begin(); it != block->nodes().end(); ++ it) {
auto node = *it;
if (node->kind() == prim::PythonOp

View File

@@ -0,0 +1,53 @@
#include "torch/csrc/jit/ir.h"

#include <unordered_set>

#include "torch/csrc/jit/interned_strings.h"
#include "torch/csrc/jit/passes/constant_pooling.h"
#include "torch/csrc/jit/node_hashing.h"

namespace torch { namespace jit {

namespace {

// Very similar to the common subexpression elimination pass:
// move all constants to the beginning of the graph, and deduplicate.
void ConstantPooling(Block* block, std::unordered_set<Node*, HashNode, EqualNode>& constants) {
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    auto node = *it;
    // node may be moved to a different block so advance iterator now
    ++it;
    if (!node->blocks().empty()) {
      // Traverse sub-blocks.
      for (auto block : node->blocks()) {
        ConstantPooling(block, constants);
      }
      continue;
    }
    if (node->kind() != prim::Constant) {
      continue;
    }
    auto first_node = node->owningGraph()->block()->nodes().front();
    if (node != first_node)
      node->moveBefore(first_node);
    // Check whether the same constant already exists.
    auto subit = constants.insert(node);
    if (!subit.second) {
      // constant exists, replace the uses of node, and destroy it.
      auto existing = *subit.first;
      node->replaceAllUsesWith(existing);
      node->destroy();
    }
  }
}

} // anonymous namespace

void ConstantPooling(const std::shared_ptr<Graph>& graph) {
  std::unordered_set<Node*, HashNode, EqualNode> constants;
  ConstantPooling(graph->block(), constants);
}

}}
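Because the recursion descends into sub-blocks but always moves constants before the first node of the owning graph's top-level block, equal constants created in different branches also collapse into a single hoisted node. A minimal sketch of this behavior, using a hypothetical function and the Python binding registered above:

import torch

@torch.jit.script
def branchy(x, cond):
    # each branch materializes its own constant 2 inside a sub-block
    if bool(cond):
        y = x + 2
    else:
        y = x - 2
    return y

torch._C._jit_pass_constant_pooling(branchy.graph)
# both branches now share one prim::Constant[value=2]() placed above the prim::If
print(branchy.graph)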

View File

@@ -0,0 +1,9 @@
#pragma once

#include "torch/csrc/jit/ir.h"

namespace torch { namespace jit {

TORCH_API void ConstantPooling(const std::shared_ptr<Graph>& graph);

}}

View File

@@ -1,6 +1,7 @@
#include "torch/csrc/jit/script/compiler.h"
#include "torch/csrc/jit/passes/lower_tuples.h"
#include "torch/csrc/jit/passes/annotate_effects.h"
#include "torch/csrc/jit/passes/constant_pooling.h"
#include "torch/csrc/jit/operator.h"
#include "torch/csrc/jit/interpreter.h"
#include "torch/csrc/jit/ir.h"
@@ -903,6 +904,7 @@ struct to_ir {
AnnotateEffects(graph);
// remove any uses of tuples that we inserted that are not needed
LowerSimpleTuples(graph);
ConstantPooling(graph);
}
private: