mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
constant pooling pass (#12222)
Summary: Add a pass to move all constants to the beginning of the graph, and deduplicate. This extends https://github.com/pytorch/pytorch/pull/10231 to also handle constants introduced in inlining, constant propagation, etc. Pull Request resolved: https://github.com/pytorch/pytorch/pull/12222 Reviewed By: driazati Differential Revision: D10201616 Pulled By: eellison fbshipit-source-id: bc9c5be26868c8b5414257a0d4462de025aeb9bd
This commit is contained in:
committed by
Facebook Github Bot
parent
83b4dc6822
commit
00aedfc0e2
@ -4,19 +4,18 @@ graph(%x.1_data : Dynamic
|
||||
%y_data : Dynamic
|
||||
%y_mask : Dynamic
|
||||
%y_dims : Dynamic) {
|
||||
%6 : int = prim::Constant[value=10]()
|
||||
%6 : int = prim::Constant[value=1]()
|
||||
%7 : bool = prim::Constant[value=1]()
|
||||
%x : Dynamic, %9 : Dynamic, %10 : Dynamic = prim::Loop(%6, %7, %x.1_data, %x.1_mask, %x.1_dims)
|
||||
%8 : int = prim::Constant[value=10]()
|
||||
%x : Dynamic, %10 : Dynamic, %11 : Dynamic = prim::Loop(%8, %7, %x.1_data, %x.1_mask, %x.1_dims)
|
||||
block0(%loop_num : int, %5_data : Dynamic, %5_mask : Dynamic, %5_dims : Dynamic) {
|
||||
%15 : int = prim::Constant[value=1]()
|
||||
%16 : Long() = prim::NumToTensor(%15)
|
||||
%16 : Long() = prim::NumToTensor(%6)
|
||||
%alpha : float = prim::TensorToNum(%16)
|
||||
%data.1 : Dynamic = aten::add(%5_data, %y_data, %alpha)
|
||||
%mask : Dynamic = aten::mul(%5_mask, %y_mask)
|
||||
%dims : Dynamic = aten::__or__(%5_dims, %y_dims)
|
||||
%21 : bool = prim::Constant[value=1]()
|
||||
%data : Dynamic = aten::where(%mask, %data.1, %5_data)
|
||||
-> (%21, %data, %mask, %dims)
|
||||
-> (%7, %data, %mask, %dims)
|
||||
}
|
||||
return (%x, %9, %10);
|
||||
return (%x, %10, %11);
|
||||
}
|
||||
|
@ -4,38 +4,36 @@ graph(%a.1_data : Dynamic
|
||||
%b_data : Dynamic
|
||||
%b_mask : Dynamic
|
||||
%b_dims : Dynamic) {
|
||||
%6 : Dynamic = aten::gt(%a.1_data, %b_data)
|
||||
%7 : Dynamic = aten::mul(%a.1_mask, %b_mask)
|
||||
%8 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
|
||||
%9 : bool = prim::TensorToBool(%6)
|
||||
%10 : int = prim::Constant[value=1]()
|
||||
%11 : Long() = prim::NumToTensor(%10)
|
||||
%6 : int = prim::Constant[value=1]()
|
||||
%7 : Dynamic = aten::gt(%a.1_data, %b_data)
|
||||
%8 : Dynamic = aten::mul(%a.1_mask, %b_mask)
|
||||
%9 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
|
||||
%10 : bool = prim::TensorToBool(%7)
|
||||
%11 : Long() = prim::NumToTensor(%6)
|
||||
%alpha.1 : float = prim::TensorToNum(%11)
|
||||
%data.1 : Dynamic = aten::add(%a.1_data, %b_data, %alpha.1)
|
||||
%mask.1 : Dynamic = aten::mul(%a.1_mask, %b_mask)
|
||||
%dims.1 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
|
||||
%16 : int = prim::Constant[value=1]()
|
||||
%17 : Long() = prim::NumToTensor(%16)
|
||||
%alpha : float = prim::TensorToNum(%17)
|
||||
%16 : Long() = prim::NumToTensor(%6)
|
||||
%alpha : float = prim::TensorToNum(%16)
|
||||
%data.4 : Dynamic = aten::sub(%a.1_data, %b_data, %alpha)
|
||||
%mask : Dynamic = aten::mul(%a.1_mask, %b_mask)
|
||||
%dims : Dynamic = aten::__or__(%a.1_dims, %b_dims)
|
||||
%21 : bool = prim::Constant[value=1]()
|
||||
%22 : int = prim::Constant[value=1]()
|
||||
%23 : Dynamic = aten::type_as(%7, %6)
|
||||
%cond_mask.1 : Dynamic = aten::mul(%6, %23)
|
||||
%23 : Dynamic = aten::type_as(%8, %7)
|
||||
%cond_mask.1 : Dynamic = aten::mul(%7, %23)
|
||||
%25 : int = aten::dim(%cond_mask.1)
|
||||
%26 : bool = aten::eq(%25, %22)
|
||||
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%26)
|
||||
block0() {
|
||||
%30 : int = aten::dim(%data.1)
|
||||
%31 : int = aten::sub(%30, %22)
|
||||
%32 : bool = prim::Constant[value=1]()
|
||||
%data.3 : Dynamic = prim::Loop(%31, %32, %cond_mask.1)
|
||||
block0(%_ : int, %35 : Dynamic) {
|
||||
%36 : int = aten::dim(%35)
|
||||
%data.2 : Dynamic = aten::unsqueeze(%35, %36)
|
||||
%38 : bool = prim::Constant[value=1]()
|
||||
-> (%38, %data.2)
|
||||
%data.3 : Dynamic = prim::Loop(%31, %21, %cond_mask.1)
|
||||
block0(%_ : int, %34 : Dynamic) {
|
||||
%35 : int = aten::dim(%34)
|
||||
%data.2 : Dynamic = aten::unsqueeze(%34, %35)
|
||||
-> (%21, %data.2)
|
||||
}
|
||||
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
|
||||
%cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask.1)
|
||||
|
@ -4,39 +4,37 @@ graph(%a.1_data : Dynamic
|
||||
%b_data : Dynamic
|
||||
%b_mask : Dynamic
|
||||
%b_dims : Dynamic) {
|
||||
%6 : float = prim::Constant[value=0.1]()
|
||||
%7 : Float() = prim::NumToTensor(%6)
|
||||
%other : float = prim::TensorToNum(%7)
|
||||
%9 : Dynamic = aten::gt(%a.1_data, %other)
|
||||
%10 : bool = prim::TensorToBool(%9)
|
||||
%11 : int = prim::Constant[value=1]()
|
||||
%12 : Long() = prim::NumToTensor(%11)
|
||||
%6 : int = prim::Constant[value=1]()
|
||||
%7 : float = prim::Constant[value=0.1]()
|
||||
%8 : Float() = prim::NumToTensor(%7)
|
||||
%other : float = prim::TensorToNum(%8)
|
||||
%10 : Dynamic = aten::gt(%a.1_data, %other)
|
||||
%11 : bool = prim::TensorToBool(%10)
|
||||
%12 : Long() = prim::NumToTensor(%6)
|
||||
%alpha.1 : float = prim::TensorToNum(%12)
|
||||
%data.1 : Dynamic = aten::add(%a.1_data, %b_data, %alpha.1)
|
||||
%mask.1 : Dynamic = aten::mul(%a.1_mask, %b_mask)
|
||||
%dims.1 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
|
||||
%17 : int = prim::Constant[value=1]()
|
||||
%18 : Long() = prim::NumToTensor(%17)
|
||||
%alpha : float = prim::TensorToNum(%18)
|
||||
%17 : Long() = prim::NumToTensor(%6)
|
||||
%alpha : float = prim::TensorToNum(%17)
|
||||
%data.4 : Dynamic = aten::sub(%a.1_data, %b_data, %alpha)
|
||||
%mask : Dynamic = aten::mul(%a.1_mask, %b_mask)
|
||||
%dims : Dynamic = aten::__or__(%a.1_dims, %b_dims)
|
||||
%22 : bool = prim::Constant[value=1]()
|
||||
%23 : int = prim::Constant[value=1]()
|
||||
%24 : Dynamic = aten::type_as(%a.1_mask, %9)
|
||||
%cond_mask.1 : Dynamic = aten::mul(%9, %24)
|
||||
%24 : Dynamic = aten::type_as(%a.1_mask, %10)
|
||||
%cond_mask.1 : Dynamic = aten::mul(%10, %24)
|
||||
%26 : int = aten::dim(%cond_mask.1)
|
||||
%27 : bool = aten::eq(%26, %23)
|
||||
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%27)
|
||||
block0() {
|
||||
%31 : int = aten::dim(%data.1)
|
||||
%32 : int = aten::sub(%31, %23)
|
||||
%33 : bool = prim::Constant[value=1]()
|
||||
%data.3 : Dynamic = prim::Loop(%32, %33, %cond_mask.1)
|
||||
block0(%_ : int, %36 : Dynamic) {
|
||||
%37 : int = aten::dim(%36)
|
||||
%data.2 : Dynamic = aten::unsqueeze(%36, %37)
|
||||
%39 : bool = prim::Constant[value=1]()
|
||||
-> (%39, %data.2)
|
||||
%data.3 : Dynamic = prim::Loop(%32, %22, %cond_mask.1)
|
||||
block0(%_ : int, %35 : Dynamic) {
|
||||
%36 : int = aten::dim(%35)
|
||||
%data.2 : Dynamic = aten::unsqueeze(%35, %36)
|
||||
-> (%22, %data.2)
|
||||
}
|
||||
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
|
||||
%cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask.1)
|
||||
|
@ -4,32 +4,31 @@ graph(%a.1_data : Dynamic
|
||||
%b_data : Dynamic
|
||||
%b_mask : Dynamic
|
||||
%b_dims : Dynamic) {
|
||||
%6 : Dynamic = aten::gt(%a.1_data, %b_data)
|
||||
%7 : Dynamic = aten::mul(%a.1_mask, %b_mask)
|
||||
%8 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
|
||||
%9 : bool = prim::TensorToBool(%6)
|
||||
%10 : int = prim::Constant[value=1]()
|
||||
%11 : Long() = prim::NumToTensor(%10)
|
||||
%6 : int = prim::Constant[value=1]()
|
||||
%7 : Dynamic = aten::gt(%a.1_data, %b_data)
|
||||
%8 : Dynamic = aten::mul(%a.1_mask, %b_mask)
|
||||
%9 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
|
||||
%10 : bool = prim::TensorToBool(%7)
|
||||
%11 : Long() = prim::NumToTensor(%6)
|
||||
%alpha : float = prim::TensorToNum(%11)
|
||||
%data.1 : Dynamic = aten::add(%a.1_data, %b_data, %alpha)
|
||||
%mask : Dynamic = aten::mul(%a.1_mask, %b_mask)
|
||||
%dims : Dynamic = aten::__or__(%a.1_dims, %b_dims)
|
||||
%16 : int = prim::Constant[value=1]()
|
||||
%17 : Dynamic = aten::type_as(%7, %6)
|
||||
%cond_mask.1 : Dynamic = aten::mul(%6, %17)
|
||||
%19 : int = aten::dim(%cond_mask.1)
|
||||
%20 : bool = aten::eq(%19, %16)
|
||||
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%20)
|
||||
%16 : bool = prim::Constant[value=1]()
|
||||
%17 : int = prim::Constant[value=1]()
|
||||
%18 : Dynamic = aten::type_as(%8, %7)
|
||||
%cond_mask.1 : Dynamic = aten::mul(%7, %18)
|
||||
%20 : int = aten::dim(%cond_mask.1)
|
||||
%21 : bool = aten::eq(%20, %17)
|
||||
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%21)
|
||||
block0() {
|
||||
%24 : int = aten::dim(%data.1)
|
||||
%25 : int = aten::sub(%24, %16)
|
||||
%26 : bool = prim::Constant[value=1]()
|
||||
%data.3 : Dynamic = prim::Loop(%25, %26, %cond_mask.1)
|
||||
%25 : int = aten::dim(%data.1)
|
||||
%26 : int = aten::sub(%25, %17)
|
||||
%data.3 : Dynamic = prim::Loop(%26, %16, %cond_mask.1)
|
||||
block0(%_ : int, %29 : Dynamic) {
|
||||
%30 : int = aten::dim(%29)
|
||||
%data.2 : Dynamic = aten::unsqueeze(%29, %30)
|
||||
%32 : bool = prim::Constant[value=1]()
|
||||
-> (%32, %data.2)
|
||||
-> (%16, %data.2)
|
||||
}
|
||||
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
|
||||
%cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask)
|
||||
|
@ -4,33 +4,32 @@ graph(%a.1_data : Dynamic
|
||||
%b_data : Dynamic
|
||||
%b_mask : Dynamic
|
||||
%b_dims : Dynamic) {
|
||||
%6 : float = prim::Constant[value=0.1]()
|
||||
%7 : Float() = prim::NumToTensor(%6)
|
||||
%other : float = prim::TensorToNum(%7)
|
||||
%9 : Dynamic = aten::gt(%a.1_data, %other)
|
||||
%10 : bool = prim::TensorToBool(%9)
|
||||
%11 : int = prim::Constant[value=1]()
|
||||
%12 : Long() = prim::NumToTensor(%11)
|
||||
%6 : int = prim::Constant[value=1]()
|
||||
%7 : float = prim::Constant[value=0.1]()
|
||||
%8 : Float() = prim::NumToTensor(%7)
|
||||
%other : float = prim::TensorToNum(%8)
|
||||
%10 : Dynamic = aten::gt(%a.1_data, %other)
|
||||
%11 : bool = prim::TensorToBool(%10)
|
||||
%12 : Long() = prim::NumToTensor(%6)
|
||||
%alpha : float = prim::TensorToNum(%12)
|
||||
%data.1 : Dynamic = aten::add(%a.1_data, %b_data, %alpha)
|
||||
%mask : Dynamic = aten::mul(%a.1_mask, %b_mask)
|
||||
%dims : Dynamic = aten::__or__(%a.1_dims, %b_dims)
|
||||
%17 : int = prim::Constant[value=1]()
|
||||
%18 : Dynamic = aten::type_as(%a.1_mask, %9)
|
||||
%cond_mask.1 : Dynamic = aten::mul(%9, %18)
|
||||
%20 : int = aten::dim(%cond_mask.1)
|
||||
%21 : bool = aten::eq(%20, %17)
|
||||
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%21)
|
||||
%17 : bool = prim::Constant[value=1]()
|
||||
%18 : int = prim::Constant[value=1]()
|
||||
%19 : Dynamic = aten::type_as(%a.1_mask, %10)
|
||||
%cond_mask.1 : Dynamic = aten::mul(%10, %19)
|
||||
%21 : int = aten::dim(%cond_mask.1)
|
||||
%22 : bool = aten::eq(%21, %18)
|
||||
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%22)
|
||||
block0() {
|
||||
%25 : int = aten::dim(%data.1)
|
||||
%26 : int = aten::sub(%25, %17)
|
||||
%27 : bool = prim::Constant[value=1]()
|
||||
%data.3 : Dynamic = prim::Loop(%26, %27, %cond_mask.1)
|
||||
%26 : int = aten::dim(%data.1)
|
||||
%27 : int = aten::sub(%26, %18)
|
||||
%data.3 : Dynamic = prim::Loop(%27, %17, %cond_mask.1)
|
||||
block0(%_ : int, %30 : Dynamic) {
|
||||
%31 : int = aten::dim(%30)
|
||||
%data.2 : Dynamic = aten::unsqueeze(%30, %31)
|
||||
%33 : bool = prim::Constant[value=1]()
|
||||
-> (%33, %data.2)
|
||||
-> (%17, %data.2)
|
||||
}
|
||||
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
|
||||
%cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask)
|
||||
|
@ -4,20 +4,20 @@ graph(%a.1_data : Dynamic
|
||||
%b_data : Dynamic
|
||||
%b_mask : Dynamic
|
||||
%b_dims : Dynamic) {
|
||||
%6 : int = prim::Constant[value=2147483647]()
|
||||
%7 : Dynamic = aten::gt(%a.1_data, %b_data)
|
||||
%8 : Dynamic = aten::mul(%a.1_mask, %b_mask)
|
||||
%9 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
|
||||
%10 : bool = prim::TensorToBool(%7)
|
||||
%11 : int = prim::Constant[value=0]()
|
||||
%12 : Dynamic = aten::mul(%7, %8)
|
||||
%13 : Dynamic = aten::sum(%12)
|
||||
%14 : Dynamic = aten::gt(%13, %11)
|
||||
%15 : bool = prim::TensorToBool(%14)
|
||||
%16 : Dynamic, %17 : Dynamic, %18 : Dynamic, %a : Dynamic, %20 : Dynamic, %21 : Dynamic = prim::Loop(%6, %15, %7, %8, %9, %a.1_data, %a.1_mask, %a.1_dims)
|
||||
%6 : int = prim::Constant[value=1]()
|
||||
%7 : int = prim::Constant[value=2147483647]()
|
||||
%8 : Dynamic = aten::gt(%a.1_data, %b_data)
|
||||
%9 : Dynamic = aten::mul(%a.1_mask, %b_mask)
|
||||
%10 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
|
||||
%11 : bool = prim::TensorToBool(%8)
|
||||
%12 : int = prim::Constant[value=0]()
|
||||
%13 : Dynamic = aten::mul(%8, %9)
|
||||
%14 : Dynamic = aten::sum(%13)
|
||||
%15 : Dynamic = aten::gt(%14, %12)
|
||||
%16 : bool = prim::TensorToBool(%15)
|
||||
%17 : Dynamic, %18 : Dynamic, %19 : Dynamic, %a : Dynamic, %21 : Dynamic, %22 : Dynamic = prim::Loop(%7, %16, %8, %9, %10, %a.1_data, %a.1_mask, %a.1_dims)
|
||||
block0(%loop_num : int, %cond_data.2 : Dynamic, %cond_mask.3 : Dynamic, %cond_dims : Dynamic, %6_data : Dynamic, %6_mask : Dynamic, %6_dims : Dynamic) {
|
||||
%29 : int = prim::Constant[value=1]()
|
||||
%30 : Long() = prim::NumToTensor(%29)
|
||||
%30 : Long() = prim::NumToTensor(%6)
|
||||
%alpha : float = prim::TensorToNum(%30)
|
||||
%data.1 : Dynamic = aten::sub(%6_data, %b_data, %alpha)
|
||||
%mask : Dynamic = aten::mul(%6_mask, %b_mask)
|
||||
@ -26,22 +26,21 @@ graph(%a.1_data : Dynamic
|
||||
%36 : Dynamic = aten::mul(%mask, %b_mask)
|
||||
%37 : Dynamic = aten::__or__(%dims, %b_dims)
|
||||
%38 : bool = prim::TensorToBool(%35)
|
||||
%39 : int = prim::Constant[value=1]()
|
||||
%40 : Dynamic = aten::type_as(%cond_mask.3, %cond_data.2)
|
||||
%cond_mask.1 : Dynamic = aten::mul(%cond_data.2, %40)
|
||||
%42 : int = aten::dim(%cond_mask.1)
|
||||
%43 : bool = aten::eq(%42, %39)
|
||||
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%43)
|
||||
%39 : bool = prim::Constant[value=1]()
|
||||
%40 : int = prim::Constant[value=1]()
|
||||
%41 : Dynamic = aten::type_as(%cond_mask.3, %cond_data.2)
|
||||
%cond_mask.1 : Dynamic = aten::mul(%cond_data.2, %41)
|
||||
%43 : int = aten::dim(%cond_mask.1)
|
||||
%44 : bool = aten::eq(%43, %40)
|
||||
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%44)
|
||||
block0() {
|
||||
%47 : int = aten::dim(%data.1)
|
||||
%48 : int = aten::sub(%47, %39)
|
||||
%49 : bool = prim::Constant[value=1]()
|
||||
%data.3 : Dynamic = prim::Loop(%48, %49, %cond_mask.1)
|
||||
%48 : int = aten::dim(%data.1)
|
||||
%49 : int = aten::sub(%48, %40)
|
||||
%data.3 : Dynamic = prim::Loop(%49, %39, %cond_mask.1)
|
||||
block0(%_ : int, %52 : Dynamic) {
|
||||
%53 : int = aten::dim(%52)
|
||||
%data.2 : Dynamic = aten::unsqueeze(%52, %53)
|
||||
%55 : bool = prim::Constant[value=1]()
|
||||
-> (%55, %data.2)
|
||||
-> (%39, %data.2)
|
||||
}
|
||||
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
|
||||
%cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask)
|
||||
@ -53,12 +52,12 @@ graph(%a.1_data : Dynamic
|
||||
%res_data : Dynamic = aten::where(%cond_data, %data.1, %6_data)
|
||||
%res_mask : Dynamic = aten::where(%cond_mask, %mask, %6_mask)
|
||||
%res_dims : Dynamic = aten::__or__(%dims, %6_dims)
|
||||
%61 : int = prim::Constant[value=0]()
|
||||
%62 : Dynamic = aten::mul(%35, %36)
|
||||
%63 : Dynamic = aten::sum(%62)
|
||||
%64 : Dynamic = aten::gt(%63, %61)
|
||||
%65 : bool = prim::TensorToBool(%64)
|
||||
-> (%65, %35, %36, %37, %res_data, %res_mask, %res_dims)
|
||||
%60 : int = prim::Constant[value=0]()
|
||||
%61 : Dynamic = aten::mul(%35, %36)
|
||||
%62 : Dynamic = aten::sum(%61)
|
||||
%63 : Dynamic = aten::gt(%62, %60)
|
||||
%64 : bool = prim::TensorToBool(%63)
|
||||
-> (%64, %35, %36, %37, %res_data, %res_mask, %res_dims)
|
||||
}
|
||||
return (%a, %20, %21);
|
||||
return (%a, %21, %22);
|
||||
}
|
||||
|
@ -21,11 +21,8 @@ graph(%a : Dynamic
|
||||
}
|
||||
%c0.5 : int = aten::add(%c0.4, %c2.1)
|
||||
%11 : int = prim::Constant[value=5]()
|
||||
%12 : int = prim::Constant[value=1]()
|
||||
%13 : Dynamic = aten::add(%a, %c0.5, %12)
|
||||
%14 : int = prim::Constant[value=1]()
|
||||
%15 : Dynamic = aten::add(%13, %c1, %14)
|
||||
%16 : int = prim::Constant[value=1]()
|
||||
%17 : Dynamic = aten::add(%15, %11, %16)
|
||||
return (%17);
|
||||
%12 : Dynamic = aten::add(%a, %c0.5, %c2.1)
|
||||
%13 : Dynamic = aten::add(%12, %c1, %c2.1)
|
||||
%14 : Dynamic = aten::add(%13, %11, %c2.1)
|
||||
return (%14);
|
||||
}
|
||||
|
@ -1,19 +1,17 @@
|
||||
graph() {
|
||||
%b.4 : int = prim::Constant[value=2]()
|
||||
%b.2 : int = prim::Constant[value=1]()
|
||||
%2 : int = prim::Constant[value=2147483647]()
|
||||
%0 : bool = prim::Constant[value=0]()
|
||||
%1 : bool = prim::Constant[value=1]()
|
||||
%b.1 : int = prim::Constant[value=0]()
|
||||
%4 : bool = prim::Constant[value=1]()
|
||||
%b.3 : int = prim::Loop(%2, %4, %b.1)
|
||||
block0(%6 : int, %7 : int) {
|
||||
%8 : bool = prim::Constant[value=1]()
|
||||
-> (%8, %b.2)
|
||||
%3 : int = prim::Constant[value=2147483647]()
|
||||
%b.2 : int = prim::Constant[value=1]()
|
||||
%b.4 : int = prim::Constant[value=2]()
|
||||
%b.3 : int = prim::Loop(%3, %1, %b.1)
|
||||
block0(%7 : int, %8 : int) {
|
||||
-> (%1, %b.2)
|
||||
}
|
||||
%9 : bool = prim::Constant[value=0]()
|
||||
%b : int = prim::Loop(%2, %9, %b.3)
|
||||
block0(%11 : int, %12 : int) {
|
||||
%13 : bool = prim::Constant[value=0]()
|
||||
-> (%13, %b.4)
|
||||
%b : int = prim::Loop(%3, %0, %b.3)
|
||||
block0(%10 : int, %11 : int) {
|
||||
-> (%0, %b.4)
|
||||
}
|
||||
return (%b);
|
||||
}
|
||||
|
@ -1,8 +1,8 @@
|
||||
graph(%input_tensor : Dynamic) {
|
||||
%1 : int = prim::Constant[value=6]()
|
||||
= prim::Print(%1)
|
||||
%2 : int = prim::Constant[value=8]()
|
||||
%3 : int = prim::Constant[value=1]()
|
||||
%4 : Dynamic = aten::add(%input_tensor, %2, %3)
|
||||
%1 : int = prim::Constant[value=1]()
|
||||
%2 : int = prim::Constant[value=6]()
|
||||
= prim::Print(%2)
|
||||
%3 : int = prim::Constant[value=8]()
|
||||
%4 : Dynamic = aten::add(%input_tensor, %3, %1)
|
||||
return (%4);
|
||||
}
|
||||
|
@ -1,11 +1,11 @@
|
||||
graph() {
|
||||
%0 : int = prim::Constant[value=2]()
|
||||
%1 : int[] = prim::Constant[value=[3]]()
|
||||
%2 : int = prim::Constant[value=6]()
|
||||
%3 : int = prim::Constant[value=0]()
|
||||
%4 : int[] = prim::Constant[value=[0, -1]]()
|
||||
%a : Dynamic = aten::randn(%1, %2, %3, %4)
|
||||
%6 : int = prim::Constant[value=1]()
|
||||
%b : Dynamic = aten::add(%a, %0, %6)
|
||||
%0 : int = prim::Constant[value=1]()
|
||||
%1 : int[] = prim::Constant[value=[0, -1]]()
|
||||
%2 : int = prim::Constant[value=0]()
|
||||
%3 : int = prim::Constant[value=6]()
|
||||
%4 : int = prim::Constant[value=2]()
|
||||
%5 : int[] = prim::Constant[value=[3]]()
|
||||
%a : Dynamic = aten::randn(%5, %3, %2, %1)
|
||||
%b : Dynamic = aten::add(%a, %4, %0)
|
||||
return (%b);
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
graph(%input_tensor : Dynamic) {
|
||||
%1 : int = prim::Constant[value=8]()
|
||||
%2 : int = prim::Constant[value=1]()
|
||||
%3 : Dynamic = aten::add(%input_tensor, %1, %2)
|
||||
%1 : int = prim::Constant[value=1]()
|
||||
%2 : int = prim::Constant[value=8]()
|
||||
%3 : Dynamic = aten::add(%input_tensor, %2, %1)
|
||||
return (%3);
|
||||
}
|
||||
|
@ -3,23 +3,21 @@ graph(%mat : Dynamic
|
||||
%mat2 : Dynamic
|
||||
%alpha : Dynamic
|
||||
%beta : Dynamic) {
|
||||
%5 : float = prim::Constant[value=2]()
|
||||
%5 : int = prim::Constant[value=1]()
|
||||
%6 : float = prim::Constant[value=4.2]()
|
||||
%7 : Dynamic = aten::mm(%mat1, %mat2)
|
||||
%8 : int = prim::Constant[value=1]()
|
||||
%9 : Dynamic = aten::add(%mat, %7, %8)
|
||||
%10 : Dynamic = aten::mm(%mat1, %mat2)
|
||||
%11 : int = prim::Constant[value=1]()
|
||||
%12 : Dynamic = aten::add(%mat, %10, %11)
|
||||
%c : Dynamic = aten::addmm(%mat, %mat1, %mat2, %5, %6)
|
||||
%14 : int = prim::TensorToNum(%alpha)
|
||||
%15 : int = prim::TensorToNum(%beta)
|
||||
%d : Dynamic = aten::addmm(%mat, %mat1, %mat2, %15, %14)
|
||||
%17 : int = prim::Constant[value=1]()
|
||||
%18 : Dynamic = aten::add(%9, %12, %17)
|
||||
%19 : int = prim::Constant[value=1]()
|
||||
%20 : Dynamic = aten::add(%18, %c, %19)
|
||||
%21 : int = prim::Constant[value=1]()
|
||||
%22 : Dynamic = aten::add(%20, %d, %21)
|
||||
return (%22);
|
||||
%7 : float = prim::Constant[value=2]()
|
||||
%8 : Dynamic = aten::mm(%mat1, %mat2)
|
||||
%9 : int = prim::Constant[value=1]()
|
||||
%10 : Dynamic = aten::add(%mat, %8, %9)
|
||||
%11 : Dynamic = aten::mm(%mat1, %mat2)
|
||||
%12 : int = prim::Constant[value=1]()
|
||||
%13 : Dynamic = aten::add(%mat, %11, %12)
|
||||
%c : Dynamic = aten::addmm(%mat, %mat1, %mat2, %7, %6)
|
||||
%15 : int = prim::TensorToNum(%alpha)
|
||||
%16 : int = prim::TensorToNum(%beta)
|
||||
%d : Dynamic = aten::addmm(%mat, %mat1, %mat2, %16, %15)
|
||||
%18 : Dynamic = aten::add(%10, %13, %5)
|
||||
%19 : Dynamic = aten::add(%18, %c, %5)
|
||||
%20 : Dynamic = aten::add(%19, %d, %5)
|
||||
return (%20);
|
||||
}
|
||||
|
@ -6,27 +6,26 @@ ModelProto {
|
||||
GraphProto {
|
||||
name: "torch-jit-export"
|
||||
inputs: [{name: "y.1", type:Tensor dims: 3 4 1}]
|
||||
outputs: [{name: "5", type:Tensor dims: 3 4 4}]
|
||||
outputs: [{name: "6", type:Tensor dims: 3 4 4}]
|
||||
initializers: []
|
||||
nodes: [
|
||||
Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "Constant", inputs: [], outputs: [2], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "Constant", inputs: [], outputs: [4], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "Loop", inputs: [3,4,y.1], outputs: [5], attributes: [{ name: 'body', type: graph, value:
|
||||
Node {type: "Constant", inputs: [], outputs: [5], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "Loop", inputs: [3,2,y.1], outputs: [6], attributes: [{ name: 'body', type: graph, value:
|
||||
GraphProto {
|
||||
name: "torch-jit-export1"
|
||||
inputs: [{name: "i", type:Tensor dims: },{name: "cond", type:Tensor dims: },{name: "8", type:Tensor dims: }]
|
||||
outputs: [{name: "15", type:Tensor dims: },{name: "16", type:Tensor dims: }]
|
||||
inputs: [{name: "i", type:Tensor dims: },{name: "cond", type:Tensor dims: },{name: "9", type:Tensor dims: }]
|
||||
outputs: [{name: "2", type:Tensor dims: },{name: "15", type:Tensor dims: }]
|
||||
initializers: []
|
||||
nodes: [
|
||||
Node {type: "Unsqueeze", inputs: [2], outputs: [9], attributes: [{ name: 'axes', type: ints, values: [0]}]},
|
||||
Node {type: "Unsqueeze", inputs: [1], outputs: [10], attributes: [{ name: 'axes', type: ints, values: [0]}]},
|
||||
Node {type: "Unsqueeze", inputs: [i], outputs: [11], attributes: [{ name: 'axes', type: ints, values: [0]}]},
|
||||
Node {type: "Concat", inputs: [9,10,11], outputs: [12], attributes: [{ name: 'axis', type: int, value: 0}]},
|
||||
Node {type: "Constant", inputs: [], outputs: [13], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "ATen", inputs: [y.1,12,13], outputs: [16], attributes: [{ name: 'operator', type: string, value: 'expand'}]},
|
||||
Node {type: "Constant", inputs: [], outputs: [15], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}
|
||||
Node {type: "Unsqueeze", inputs: [4], outputs: [10], attributes: [{ name: 'axes', type: ints, values: [0]}]},
|
||||
Node {type: "Unsqueeze", inputs: [5], outputs: [11], attributes: [{ name: 'axes', type: ints, values: [0]}]},
|
||||
Node {type: "Unsqueeze", inputs: [i], outputs: [12], attributes: [{ name: 'axes', type: ints, values: [0]}]},
|
||||
Node {type: "Concat", inputs: [10,11,12], outputs: [13], attributes: [{ name: 'axis', type: int, value: 0}]},
|
||||
Node {type: "ATen", inputs: [y.1,13,1], outputs: [15], attributes: [{ name: 'operator', type: string, value: 'expand'}]}
|
||||
]
|
||||
}
|
||||
|
||||
|
@ -5,13 +5,13 @@ graph(%0 : Float(*, *)
|
||||
%4 : Float(*, *)
|
||||
%5 : Float(*)
|
||||
%6 : Float(*)) {
|
||||
%7 : Float(*, *) = aten::t(%3)
|
||||
%8 : Float(*, *) = aten::mm(%0, %7)
|
||||
%9 : int = prim::Constant[value=1]()
|
||||
%10 : Float(*, *) = aten::add(%5, %8, %9)
|
||||
%7 : int = prim::Constant[value=1]()
|
||||
%8 : Float(*, *) = aten::t(%3)
|
||||
%9 : Float(*, *) = aten::mm(%0, %8)
|
||||
%10 : Float(*, *) = aten::add(%5, %9, %7)
|
||||
%11 : Float(*, *) = aten::t(%4)
|
||||
%12 : Float(*, *) = aten::mm(%1, %11)
|
||||
%13 : Float(*, *) = aten::add(%6, %12, %9)
|
||||
%13 : Float(*, *) = aten::add(%6, %12, %7)
|
||||
%14 : Dynamic[] = prim::ListConstruct(%10, %13)
|
||||
%15 : Dynamic[] = aten::broadcast_tensors(%14)
|
||||
%16 : Dynamic, %17 : Dynamic = prim::ListUnpack(%15)
|
||||
|
@ -5,13 +5,13 @@ graph(%0 : Float(*, *)
|
||||
%4 : Float(*, *)
|
||||
%5 : Float(*)
|
||||
%6 : Float(*)) {
|
||||
%7 : Float(*, *) = aten::t(%3)
|
||||
%8 : Float(*, *) = aten::mm(%0, %7)
|
||||
%9 : int = prim::Constant[value=1]()
|
||||
%10 : Float(*, *) = aten::add(%5, %8, %9)
|
||||
%7 : int = prim::Constant[value=1]()
|
||||
%8 : Float(*, *) = aten::t(%3)
|
||||
%9 : Float(*, *) = aten::mm(%0, %8)
|
||||
%10 : Float(*, *) = aten::add(%5, %9, %7)
|
||||
%11 : Float(*, *) = aten::t(%4)
|
||||
%12 : Float(*, *) = aten::mm(%1, %11)
|
||||
%13 : Float(*, *) = aten::add(%6, %12, %9)
|
||||
%13 : Float(*, *) = aten::add(%6, %12, %7)
|
||||
%14 : Dynamic[] = prim::ListConstruct(%10, %13)
|
||||
%15 : Dynamic[] = aten::broadcast_tensors(%14)
|
||||
%16 : Dynamic, %17 : Dynamic = prim::ListUnpack(%15)
|
||||
|
@ -1,7 +1,6 @@
|
||||
graph(%x : Dynamic) {
|
||||
%1 : int = prim::Constant[value=1]()
|
||||
%2 : Dynamic = ^python_fn()(%x)
|
||||
%3 : int = prim::Constant[value=1]()
|
||||
%4 : Dynamic = aten::add(%2, %1, %3)
|
||||
return (%4);
|
||||
%3 : Dynamic = aten::add(%2, %1, %1)
|
||||
return (%3);
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
graph(%x : Dynamic) {
|
||||
%2 : int = prim::Constant[value=1]()
|
||||
%1 : Dynamic = ^<python_value>()(%x)
|
||||
%3 : int = prim::Constant[value=1]()
|
||||
%4 : Dynamic = aten::add(%1, %2, %3)
|
||||
%4 : Dynamic = aten::add(%1, %2, %2)
|
||||
return (%4);
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
graph(%x : Dynamic) {
|
||||
%1 : int = prim::Constant[value=1]()
|
||||
%2 : Dynamic = aten::neg(%x)
|
||||
%3 : int = prim::Constant[value=1]()
|
||||
%4 : Dynamic = aten::add(%2, %1, %3)
|
||||
return (%4);
|
||||
%3 : Dynamic = aten::add(%2, %1, %1)
|
||||
return (%3);
|
||||
}
|
||||
|
@ -1,14 +1,13 @@
|
||||
graph(%x : Dynamic) {
|
||||
%1 : int = prim::Constant[value=1]()
|
||||
%2 : int = prim::Constant[value=3]()
|
||||
%3 : int = prim::Constant[value=4]()
|
||||
%4 : int[] = prim::ListConstruct(%3, %2)
|
||||
%5 : int = prim::Constant[value=6]()
|
||||
%6 : int = prim::Constant[value=0]()
|
||||
%7 : int[] = prim::Constant[value=[0, -1]]()
|
||||
%8 : Dynamic = aten::zeros(%4, %5, %6, %7)
|
||||
%1 : int = prim::Constant[value=3]()
|
||||
%2 : int = prim::Constant[value=4]()
|
||||
%3 : int = prim::Constant[value=6]()
|
||||
%4 : int = prim::Constant[value=0]()
|
||||
%5 : int[] = prim::Constant[value=[0, -1]]()
|
||||
%6 : int = prim::Constant[value=1]()
|
||||
%7 : int[] = prim::ListConstruct(%2, %1)
|
||||
%8 : Dynamic = aten::zeros(%7, %3, %4, %5)
|
||||
%9 : Dynamic = aten::mm(%x, %8)
|
||||
%10 : int = prim::Constant[value=1]()
|
||||
%11 : Dynamic = aten::add(%9, %1, %10)
|
||||
return (%11);
|
||||
%10 : Dynamic = aten::add(%9, %6, %6)
|
||||
return (%10);
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
graph(%x : Dynamic) {
|
||||
%2 : int = prim::Constant[value=1]()
|
||||
%1 : Double(3, 4) = aten::neg(%x)
|
||||
%3 : int = prim::Constant[value=1]()
|
||||
%4 : Dynamic = aten::add(%1, %2, %3)
|
||||
%4 : Dynamic = aten::add(%1, %2, %2)
|
||||
return (%4);
|
||||
}
|
||||
|
@ -1,14 +1,13 @@
|
||||
graph(%x : Dynamic) {
|
||||
%9 : int = prim::Constant[value=1]()
|
||||
%1 : int = prim::Constant[value=4]()
|
||||
%2 : int = prim::Constant[value=3]()
|
||||
%3 : int[] = prim::ListConstruct(%1, %2)
|
||||
%4 : int = prim::Constant[value=7]()
|
||||
%5 : int = prim::Constant[value=0]()
|
||||
%6 : int[] = prim::Constant[value=[0, -1]]()
|
||||
%5 : int = prim::Constant[value=0]()
|
||||
%4 : int = prim::Constant[value=7]()
|
||||
%2 : int = prim::Constant[value=3]()
|
||||
%1 : int = prim::Constant[value=4]()
|
||||
%9 : int = prim::Constant[value=1]()
|
||||
%3 : int[] = prim::ListConstruct(%1, %2)
|
||||
%7 : Double(4, 3) = aten::zeros(%3, %4, %5, %6)
|
||||
%8 : Double(3, 3) = aten::mm(%x, %7)
|
||||
%10 : int = prim::Constant[value=1]()
|
||||
%11 : Dynamic = aten::add(%8, %9, %10)
|
||||
%11 : Dynamic = aten::add(%8, %9, %9)
|
||||
return (%11);
|
||||
}
|
||||
|
33
test/expect/TestScript.test_constant_pooling.expect
Normal file
33
test/expect/TestScript.test_constant_pooling.expect
Normal file
@ -0,0 +1,33 @@
|
||||
graph(%cond : Dynamic) {
|
||||
%1 : int[] = prim::Constant[value=[1]]()
|
||||
%2 : int[] = prim::Constant[value=[0]]()
|
||||
%3 : int = prim::Constant[value=3]()
|
||||
%4 : Float(2) = prim::Constant[value= 4 4 [ CPUFloatType{2} ]]()
|
||||
%5 : Float(2) = prim::Constant[value= 1 1 [ CPUFloatType{2} ]]()
|
||||
%c.1 : int = prim::Constant[value=0]()
|
||||
%a : int = prim::Constant[value=1]()
|
||||
%d : string = prim::Constant[value="abc"]()
|
||||
%e : string = prim::Constant[value="bcd"]()
|
||||
%10 : int = prim::Constant[value=6]()
|
||||
%11 : int[] = prim::Constant[value=[0, -1]]()
|
||||
%12 : bool = prim::TensorToBool(%cond)
|
||||
%c : int, %y : Dynamic = prim::If(%12)
|
||||
block0() {
|
||||
-> (%3, %4)
|
||||
}
|
||||
block1() {
|
||||
%y.2 : Dynamic = aten::rand(%2, %10, %c.1, %11)
|
||||
%16 : bool = prim::TensorToBool(%cond)
|
||||
%y.4 : Dynamic = prim::If(%16)
|
||||
block0() {
|
||||
%y.3 : Dynamic = aten::rand(%1, %10, %c.1, %11)
|
||||
-> (%y.3)
|
||||
}
|
||||
block1() {
|
||||
-> (%y.2)
|
||||
}
|
||||
= prim::Print(%d, %e, %d, %5, %y.4, %5)
|
||||
-> (%c.1, %y.4)
|
||||
}
|
||||
return (%a, %3, %c, %5);
|
||||
}
|
@ -1,12 +1,10 @@
|
||||
graph(%a : Dynamic) {
|
||||
%1 : Long() = prim::Constant[value={3}]()
|
||||
%1 : Long() = prim::Constant[value={7}]()
|
||||
%2 : Long() = prim::Constant[value={1}]()
|
||||
%3 : Long() = prim::Constant[value={7}]()
|
||||
%4 : Long() = aten::add(%3, %2)
|
||||
%b : Long() = aten::add(%4, %1)
|
||||
%6 : Long() = prim::Constant[value={1}]()
|
||||
%c.1 : Dynamic = aten::add(%a, %b, %6)
|
||||
%8 : Long() = prim::Constant[value={1}]()
|
||||
%c : Dynamic = aten::add(%c.1, %b, %8)
|
||||
%3 : Long() = prim::Constant[value={3}]()
|
||||
%4 : Long() = aten::add(%1, %2)
|
||||
%b : Long() = aten::add(%4, %3)
|
||||
%c.1 : Dynamic = aten::add(%a, %b, %2)
|
||||
%c : Dynamic = aten::add(%c.1, %b, %2)
|
||||
return (%c);
|
||||
}
|
||||
|
@ -11,25 +11,23 @@ ModelProto {
|
||||
nodes: [
|
||||
Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "Constant", inputs: [], outputs: [2], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "Gather", inputs: [x,2], outputs: [3], attributes: [{ name: 'axis', type: int, value: 0}]},
|
||||
Node {type: "Shape", inputs: [x], outputs: [4], attributes: []},
|
||||
Node {type: "Gather", inputs: [4,1], outputs: [5], attributes: [{ name: 'axis', type: int, value: 0}]},
|
||||
Node {type: "Constant", inputs: [], outputs: [6], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "Loop", inputs: [5,6,3], outputs: [7], attributes: [{ name: 'body', type: graph, value:
|
||||
Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "Gather", inputs: [x,2], outputs: [4], attributes: [{ name: 'axis', type: int, value: 0}]},
|
||||
Node {type: "Shape", inputs: [x], outputs: [5], attributes: []},
|
||||
Node {type: "Gather", inputs: [5,3], outputs: [6], attributes: [{ name: 'axis', type: int, value: 0}]},
|
||||
Node {type: "Loop", inputs: [6,1,4], outputs: [7], attributes: [{ name: 'body', type: graph, value:
|
||||
GraphProto {
|
||||
name: "torch-jit-export1"
|
||||
inputs: [{name: "i", type:Tensor dims: },{name: "cond", type:Tensor dims: },{name: "10", type:Tensor dims: }]
|
||||
outputs: [{name: "18", type:Tensor dims: },{name: "17", type:Tensor dims: }]
|
||||
outputs: [{name: "1", type:Tensor dims: },{name: "16", type:Tensor dims: }]
|
||||
initializers: []
|
||||
nodes: [
|
||||
Node {type: "Constant", inputs: [], outputs: [11], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "Unsqueeze", inputs: [2], outputs: [12], attributes: [{ name: 'axes', type: ints, values: [0]}]},
|
||||
Node {type: "Unsqueeze", inputs: [i], outputs: [13], attributes: [{ name: 'axes', type: ints, values: [0]}]},
|
||||
Node {type: "Unsqueeze", inputs: [11], outputs: [14], attributes: [{ name: 'axes', type: ints, values: [0]}]},
|
||||
Node {type: "DynamicSlice", inputs: [x,12,13,14], outputs: [15], attributes: []},
|
||||
Node {type: "ReduceSum", inputs: [15], outputs: [16], attributes: [{ name: 'axes', type: ints, values: [0]},{ name: 'keepdims', type: int, value: 0}]},
|
||||
Node {type: "Add", inputs: [10,16], outputs: [17], attributes: []},
|
||||
Node {type: "Constant", inputs: [], outputs: [18], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}
|
||||
Node {type: "Unsqueeze", inputs: [2], outputs: [11], attributes: [{ name: 'axes', type: ints, values: [0]}]},
|
||||
Node {type: "Unsqueeze", inputs: [i], outputs: [12], attributes: [{ name: 'axes', type: ints, values: [0]}]},
|
||||
Node {type: "Unsqueeze", inputs: [2], outputs: [13], attributes: [{ name: 'axes', type: ints, values: [0]}]},
|
||||
Node {type: "DynamicSlice", inputs: [x,11,12,13], outputs: [14], attributes: []},
|
||||
Node {type: "ReduceSum", inputs: [14], outputs: [15], attributes: [{ name: 'axes', type: ints, values: [0]},{ name: 'keepdims', type: int, value: 0}]},
|
||||
Node {type: "Add", inputs: [10,15], outputs: [16], attributes: []}
|
||||
]
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
graph(%x : Double(*, *)) {
|
||||
%1 : bool = prim::Constant[value=1]()
|
||||
%c : Dynamic[] = prim::If(%1)
|
||||
%1 : int = prim::Constant[value=0]()
|
||||
%2 : bool = prim::Constant[value=1]()
|
||||
%c : Dynamic[] = prim::If(%2)
|
||||
block0() {
|
||||
%c.1 : Dynamic[] = prim::ListConstruct(%x, %x)
|
||||
-> (%c.1)
|
||||
@ -9,7 +10,6 @@ graph(%x : Double(*, *)) {
|
||||
%c.2 : Dynamic[] = prim::ListConstruct(%x, %x, %x)
|
||||
-> (%c.2)
|
||||
}
|
||||
%5 : int = prim::Constant[value=0]()
|
||||
%6 : Dynamic = aten::cat(%c, %5)
|
||||
%6 : Dynamic = aten::cat(%c, %1)
|
||||
return (%6);
|
||||
}
|
||||
|
@ -1,34 +1,32 @@
|
||||
graph(%t : Dynamic) {
|
||||
%c1.2 : int = prim::Constant[value=0]()
|
||||
%1 : bool = prim::Constant[value=1]()
|
||||
%2 : bool = prim::Constant[value=0]()
|
||||
%c1.1 : int = prim::Constant[value=1]()
|
||||
%3 : bool = prim::Constant[value=0]()
|
||||
%4 : bool = prim::If(%3)
|
||||
%c1.2 : int = prim::Constant[value=0]()
|
||||
%5 : bool = prim::If(%2)
|
||||
block0() {
|
||||
%5 : int = prim::Constant[value=0]()
|
||||
%6 : Dynamic = aten::select(%t, %5, %c1.1)
|
||||
%6 : Dynamic = aten::select(%t, %c1.2, %c1.1)
|
||||
%7 : bool = prim::TensorToBool(%6)
|
||||
-> (%7)
|
||||
}
|
||||
block1() {
|
||||
-> (%3)
|
||||
-> (%2)
|
||||
}
|
||||
%8 : bool = prim::If(%4)
|
||||
%8 : bool = prim::If(%5)
|
||||
block0() {
|
||||
-> (%4)
|
||||
-> (%5)
|
||||
}
|
||||
block1() {
|
||||
%9 : bool = prim::Constant[value=1]()
|
||||
%10 : bool = prim::If(%9)
|
||||
%9 : bool = prim::If(%1)
|
||||
block0() {
|
||||
-> (%9)
|
||||
-> (%1)
|
||||
}
|
||||
block1() {
|
||||
%11 : int = prim::Constant[value=0]()
|
||||
%12 : Dynamic = aten::select(%t, %11, %c1.1)
|
||||
%13 : bool = prim::TensorToBool(%12)
|
||||
-> (%13)
|
||||
%10 : Dynamic = aten::select(%t, %c1.2, %c1.1)
|
||||
%11 : bool = prim::TensorToBool(%10)
|
||||
-> (%11)
|
||||
}
|
||||
-> (%10)
|
||||
-> (%9)
|
||||
}
|
||||
%c1 : int = prim::If(%8)
|
||||
block0() {
|
||||
|
@ -1,31 +1,29 @@
|
||||
graph(%x : Dynamic) {
|
||||
%1 : int = prim::Constant[value=1]()
|
||||
%1 : bool = prim::Constant[value=1]()
|
||||
%y.1 : int = prim::Constant[value=0]()
|
||||
%3 : int = prim::TensorToNum(%x)
|
||||
%4 : bool = prim::Constant[value=1]()
|
||||
%3 : int = prim::Constant[value=1]()
|
||||
%4 : int = prim::TensorToNum(%x)
|
||||
%5 : int = prim::Constant[value=8]()
|
||||
%6 : int = aten::floordiv(%3, %5)
|
||||
%6 : int = aten::floordiv(%4, %5)
|
||||
%7 : int = prim::Constant[value=8]()
|
||||
%8 : int = aten::mul(%6, %7)
|
||||
%9 : int = aten::sub(%3, %8)
|
||||
%y.3 : int = prim::Loop(%6, %4, %y.1)
|
||||
%9 : int = aten::sub(%4, %8)
|
||||
%y.3 : int = prim::Loop(%6, %1, %y.1)
|
||||
block0(%i.1 : int, %12 : int) {
|
||||
%y.12 : int = aten::add(%12, %1)
|
||||
%y.5 : int = aten::add(%y.12, %1)
|
||||
%y.6 : int = aten::add(%y.5, %1)
|
||||
%y.7 : int = aten::add(%y.6, %1)
|
||||
%y.8 : int = aten::add(%y.7, %1)
|
||||
%y.9 : int = aten::add(%y.8, %1)
|
||||
%y.10 : int = aten::add(%y.9, %1)
|
||||
%y.11 : int = aten::add(%y.10, %1)
|
||||
%21 : bool = prim::Constant[value=1]()
|
||||
-> (%21, %y.11)
|
||||
%y.12 : int = aten::add(%12, %3)
|
||||
%y.5 : int = aten::add(%y.12, %3)
|
||||
%y.6 : int = aten::add(%y.5, %3)
|
||||
%y.7 : int = aten::add(%y.6, %3)
|
||||
%y.8 : int = aten::add(%y.7, %3)
|
||||
%y.9 : int = aten::add(%y.8, %3)
|
||||
%y.10 : int = aten::add(%y.9, %3)
|
||||
%y.11 : int = aten::add(%y.10, %3)
|
||||
-> (%1, %y.11)
|
||||
}
|
||||
%y : int = prim::Loop(%9, %4, %y.3)
|
||||
block0(%i : int, %24 : int) {
|
||||
%y.4 : int = aten::add(%24, %1)
|
||||
%26 : bool = prim::Constant[value=1]()
|
||||
-> (%26, %y.4)
|
||||
%y : int = prim::Loop(%9, %1, %y.3)
|
||||
block0(%i : int, %23 : int) {
|
||||
%y.4 : int = aten::add(%23, %3)
|
||||
-> (%1, %y.4)
|
||||
}
|
||||
return (%y);
|
||||
}
|
||||
|
@ -1,14 +1,14 @@
|
||||
graph(%x : Dynamic) {
|
||||
%1 : bool = prim::Constant[value=1]()
|
||||
%y.1 : int = prim::Constant[value=0]()
|
||||
%2 : int = prim::TensorToNum(%x)
|
||||
%3 : bool = prim::Constant[value=1]()
|
||||
%3 : int = prim::TensorToNum(%x)
|
||||
%4 : int = prim::Constant[value=0]()
|
||||
%5 : int = prim::Constant[value=8]()
|
||||
%6 : int = aten::floordiv(%2, %5)
|
||||
%6 : int = aten::floordiv(%3, %5)
|
||||
%7 : int = prim::Constant[value=8]()
|
||||
%8 : int = aten::mul(%6, %7)
|
||||
%9 : int = aten::sub(%2, %8)
|
||||
%10 : Dynamic, %y.3 : int = prim::Loop(%6, %3, %4, %y.1)
|
||||
%9 : int = aten::sub(%3, %8)
|
||||
%10 : Dynamic, %y.3 : int = prim::Loop(%6, %1, %4, %y.1)
|
||||
block0(%i.1 : int, %13 : int, %14 : int) {
|
||||
%y.12 : int = aten::add(%14, %13)
|
||||
%16 : int = prim::Constant[value=1]()
|
||||
@ -32,18 +32,16 @@ graph(%x : Dynamic) {
|
||||
%34 : int = prim::Constant[value=1]()
|
||||
%35 : int = aten::add(%32, %34)
|
||||
%y.11 : int = aten::add(%y.10, %35)
|
||||
%37 : bool = prim::Constant[value=1]()
|
||||
%38 : int = prim::Constant[value=1]()
|
||||
%39 : int = aten::add(%35, %38)
|
||||
-> (%37, %39, %y.11)
|
||||
%37 : int = prim::Constant[value=1]()
|
||||
%38 : int = aten::add(%35, %37)
|
||||
-> (%1, %38, %y.11)
|
||||
}
|
||||
%40 : Dynamic, %y : int = prim::Loop(%9, %3, %10, %y.3)
|
||||
block0(%i : int, %43 : int, %44 : int) {
|
||||
%y.4 : int = aten::add(%44, %43)
|
||||
%46 : bool = prim::Constant[value=1]()
|
||||
%47 : int = prim::Constant[value=1]()
|
||||
%48 : int = aten::add(%43, %47)
|
||||
-> (%46, %48, %y.4)
|
||||
%39 : Dynamic, %y : int = prim::Loop(%9, %1, %10, %y.3)
|
||||
block0(%i : int, %42 : int, %43 : int) {
|
||||
%y.4 : int = aten::add(%43, %42)
|
||||
%45 : int = prim::Constant[value=1]()
|
||||
%46 : int = aten::add(%42, %45)
|
||||
-> (%1, %46, %y.4)
|
||||
}
|
||||
return (%y);
|
||||
}
|
||||
|
@ -1,15 +1,15 @@
|
||||
graph() {
|
||||
%0 : int = prim::Constant[value=1]()
|
||||
%y.1 : int = prim::Constant[value=0]()
|
||||
%y.11 : int = aten::add(%y.1, %0)
|
||||
%y.2 : int = aten::add(%y.11, %0)
|
||||
%y.3 : int = aten::add(%y.2, %0)
|
||||
%y.4 : int = aten::add(%y.3, %0)
|
||||
%y.5 : int = aten::add(%y.4, %0)
|
||||
%y.6 : int = aten::add(%y.5, %0)
|
||||
%y.7 : int = aten::add(%y.6, %0)
|
||||
%y.8 : int = aten::add(%y.7, %0)
|
||||
%y.9 : int = aten::add(%y.8, %0)
|
||||
%y.10 : int = aten::add(%y.9, %0)
|
||||
%1 : int = prim::Constant[value=1]()
|
||||
%y.11 : int = aten::add(%y.1, %1)
|
||||
%y.2 : int = aten::add(%y.11, %1)
|
||||
%y.3 : int = aten::add(%y.2, %1)
|
||||
%y.4 : int = aten::add(%y.3, %1)
|
||||
%y.5 : int = aten::add(%y.4, %1)
|
||||
%y.6 : int = aten::add(%y.5, %1)
|
||||
%y.7 : int = aten::add(%y.6, %1)
|
||||
%y.8 : int = aten::add(%y.7, %1)
|
||||
%y.9 : int = aten::add(%y.8, %1)
|
||||
%y.10 : int = aten::add(%y.9, %1)
|
||||
return (%y.10);
|
||||
}
|
||||
|
@ -1,56 +1,52 @@
|
||||
graph(%x : Dynamic) {
|
||||
%1 : int = prim::Constant[value=10]()
|
||||
%1 : bool = prim::Constant[value=1]()
|
||||
%y.1 : int = prim::Constant[value=0]()
|
||||
%3 : bool = prim::Constant[value=1]()
|
||||
%y : int = prim::Loop(%1, %3, %y.1)
|
||||
%3 : int = prim::Constant[value=10]()
|
||||
%y : int = prim::Loop(%3, %1, %y.1)
|
||||
block0(%i : int, %6 : int) {
|
||||
%7 : int = prim::TensorToNum(%x)
|
||||
%8 : bool = prim::Constant[value=1]()
|
||||
%9 : int = prim::Constant[value=0]()
|
||||
%10 : int = prim::Constant[value=8]()
|
||||
%11 : int = aten::floordiv(%7, %10)
|
||||
%12 : int = prim::Constant[value=8]()
|
||||
%13 : int = aten::mul(%11, %12)
|
||||
%14 : int = aten::sub(%7, %13)
|
||||
%15 : Dynamic, %y.4 : int = prim::Loop(%11, %8, %9, %6)
|
||||
block0(%j.1 : int, %18 : int, %19 : int) {
|
||||
%y.13 : int = aten::add(%19, %18)
|
||||
%21 : int = prim::Constant[value=1]()
|
||||
%22 : int = aten::add(%18, %21)
|
||||
%y.6 : int = aten::add(%y.13, %22)
|
||||
%24 : int = prim::Constant[value=1]()
|
||||
%25 : int = aten::add(%22, %24)
|
||||
%y.7 : int = aten::add(%y.6, %25)
|
||||
%27 : int = prim::Constant[value=1]()
|
||||
%28 : int = aten::add(%25, %27)
|
||||
%y.8 : int = aten::add(%y.7, %28)
|
||||
%30 : int = prim::Constant[value=1]()
|
||||
%31 : int = aten::add(%28, %30)
|
||||
%y.9 : int = aten::add(%y.8, %31)
|
||||
%33 : int = prim::Constant[value=1]()
|
||||
%34 : int = aten::add(%31, %33)
|
||||
%y.10 : int = aten::add(%y.9, %34)
|
||||
%36 : int = prim::Constant[value=1]()
|
||||
%37 : int = aten::add(%34, %36)
|
||||
%y.11 : int = aten::add(%y.10, %37)
|
||||
%39 : int = prim::Constant[value=1]()
|
||||
%40 : int = aten::add(%37, %39)
|
||||
%y.12 : int = aten::add(%y.11, %40)
|
||||
%42 : bool = prim::Constant[value=1]()
|
||||
%43 : int = prim::Constant[value=1]()
|
||||
%44 : int = aten::add(%40, %43)
|
||||
-> (%42, %44, %y.12)
|
||||
%8 : int = prim::Constant[value=0]()
|
||||
%9 : int = prim::Constant[value=8]()
|
||||
%10 : int = aten::floordiv(%7, %9)
|
||||
%11 : int = prim::Constant[value=8]()
|
||||
%12 : int = aten::mul(%10, %11)
|
||||
%13 : int = aten::sub(%7, %12)
|
||||
%14 : Dynamic, %y.4 : int = prim::Loop(%10, %1, %8, %6)
|
||||
block0(%j.1 : int, %17 : int, %18 : int) {
|
||||
%y.13 : int = aten::add(%18, %17)
|
||||
%20 : int = prim::Constant[value=1]()
|
||||
%21 : int = aten::add(%17, %20)
|
||||
%y.6 : int = aten::add(%y.13, %21)
|
||||
%23 : int = prim::Constant[value=1]()
|
||||
%24 : int = aten::add(%21, %23)
|
||||
%y.7 : int = aten::add(%y.6, %24)
|
||||
%26 : int = prim::Constant[value=1]()
|
||||
%27 : int = aten::add(%24, %26)
|
||||
%y.8 : int = aten::add(%y.7, %27)
|
||||
%29 : int = prim::Constant[value=1]()
|
||||
%30 : int = aten::add(%27, %29)
|
||||
%y.9 : int = aten::add(%y.8, %30)
|
||||
%32 : int = prim::Constant[value=1]()
|
||||
%33 : int = aten::add(%30, %32)
|
||||
%y.10 : int = aten::add(%y.9, %33)
|
||||
%35 : int = prim::Constant[value=1]()
|
||||
%36 : int = aten::add(%33, %35)
|
||||
%y.11 : int = aten::add(%y.10, %36)
|
||||
%38 : int = prim::Constant[value=1]()
|
||||
%39 : int = aten::add(%36, %38)
|
||||
%y.12 : int = aten::add(%y.11, %39)
|
||||
%41 : int = prim::Constant[value=1]()
|
||||
%42 : int = aten::add(%39, %41)
|
||||
-> (%1, %42, %y.12)
|
||||
}
|
||||
%45 : Dynamic, %y.3 : int = prim::Loop(%14, %8, %15, %y.4)
|
||||
block0(%j : int, %48 : int, %49 : int) {
|
||||
%y.5 : int = aten::add(%49, %48)
|
||||
%51 : bool = prim::Constant[value=1]()
|
||||
%52 : int = prim::Constant[value=1]()
|
||||
%53 : int = aten::add(%48, %52)
|
||||
-> (%51, %53, %y.5)
|
||||
%43 : Dynamic, %y.3 : int = prim::Loop(%13, %1, %14, %y.4)
|
||||
block0(%j : int, %46 : int, %47 : int) {
|
||||
%y.5 : int = aten::add(%47, %46)
|
||||
%49 : int = prim::Constant[value=1]()
|
||||
%50 : int = aten::add(%46, %49)
|
||||
-> (%1, %50, %y.5)
|
||||
}
|
||||
%54 : bool = prim::Constant[value=1]()
|
||||
-> (%54, %y.3)
|
||||
-> (%1, %y.3)
|
||||
}
|
||||
return (%y);
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
graph(%x : Dynamic) {
|
||||
%1 : float = prim::Constant[value=3.1]()
|
||||
%2 : float = prim::Constant[value=1.1]()
|
||||
%3 : float = aten::add(%2, %1)
|
||||
%1 : float = prim::Constant[value=1.1]()
|
||||
%2 : float = prim::Constant[value=3.1]()
|
||||
%3 : float = aten::add(%1, %2)
|
||||
return (%3);
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
graph(%x : Dynamic) {
|
||||
%1 : int = prim::Constant[value=8]()
|
||||
%2 : int = prim::Constant[value=7]()
|
||||
%3 : int = aten::add(%2, %1)
|
||||
%1 : int = prim::Constant[value=7]()
|
||||
%2 : int = prim::Constant[value=8]()
|
||||
%3 : int = aten::add(%1, %2)
|
||||
return (%3);
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
graph(%x : Dynamic) {
|
||||
%1 : int = prim::Constant[value=7]()
|
||||
%2 : int = prim::Constant[value=1]()
|
||||
%3 : Dynamic = aten::add(%x, %1, %2)
|
||||
%1 : int = prim::Constant[value=1]()
|
||||
%2 : int = prim::Constant[value=7]()
|
||||
%3 : Dynamic = aten::add(%x, %2, %1)
|
||||
return (%3);
|
||||
}
|
||||
|
@ -11,15 +11,14 @@ ModelProto {
|
||||
nodes: [
|
||||
Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "Constant", inputs: [], outputs: [2], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "Loop", inputs: [1,2,x.1], outputs: [3], attributes: [{ name: 'body', type: graph, value:
|
||||
Node {type: "Loop", inputs: [2,1,x.1], outputs: [3], attributes: [{ name: 'body', type: graph, value:
|
||||
GraphProto {
|
||||
name: "torch-jit-export1"
|
||||
inputs: [{name: "_", type:Tensor dims: },{name: "cond", type:Tensor dims: },{name: "6", type:Tensor dims: }]
|
||||
outputs: [{name: "8", type:Tensor dims: },{name: "7", type:Tensor dims: }]
|
||||
outputs: [{name: "1", type:Tensor dims: },{name: "7", type:Tensor dims: }]
|
||||
initializers: []
|
||||
nodes: [
|
||||
Node {type: "Add", inputs: [6,6], outputs: [7], attributes: []},
|
||||
Node {type: "Constant", inputs: [], outputs: [8], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}
|
||||
Node {type: "Add", inputs: [6,6], outputs: [7], attributes: []}
|
||||
]
|
||||
}
|
||||
|
||||
|
@ -12,9 +12,9 @@ ModelProto {
|
||||
Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "Constant", inputs: [], outputs: [2], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "Shape", inputs: [x], outputs: [3], attributes: []},
|
||||
Node {type: "Gather", inputs: [3,2], outputs: [4], attributes: [{ name: 'axis', type: int, value: 0}]},
|
||||
Node {type: "Gather", inputs: [3,1], outputs: [4], attributes: [{ name: 'axis', type: int, value: 0}]},
|
||||
Node {type: "Cast", inputs: [4], outputs: [5], attributes: [{ name: 'to', type: int, value: 1}]},
|
||||
Node {type: "Cast", inputs: [1], outputs: [6], attributes: [{ name: 'to', type: int, value: 1}]},
|
||||
Node {type: "Cast", inputs: [2], outputs: [6], attributes: [{ name: 'to', type: int, value: 1}]},
|
||||
Node {type: "Div", inputs: [5,6], outputs: [7], attributes: []},
|
||||
Node {type: "Add", inputs: [x,7], outputs: [8], attributes: []}
|
||||
]
|
||||
|
@ -11,14 +11,14 @@ ModelProto {
|
||||
nodes: [
|
||||
Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "Constant", inputs: [], outputs: [2], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "size", inputs: [x,2], outputs: [3], attributes: []},
|
||||
Node {type: "Constant", inputs: [], outputs: [4], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "_cast_Float", inputs: [3,4], outputs: [5], attributes: []},
|
||||
Node {type: "Constant", inputs: [], outputs: [6], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "_cast_Float", inputs: [1,6], outputs: [7], attributes: []},
|
||||
Node {type: "div", inputs: [5,7], outputs: [z], attributes: []},
|
||||
Node {type: "Constant", inputs: [], outputs: [9], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "add", inputs: [x,z,9], outputs: [10], attributes: []}
|
||||
Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "size", inputs: [x,2], outputs: [4], attributes: []},
|
||||
Node {type: "Constant", inputs: [], outputs: [5], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "_cast_Float", inputs: [4,5], outputs: [6], attributes: []},
|
||||
Node {type: "Constant", inputs: [], outputs: [7], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
|
||||
Node {type: "_cast_Float", inputs: [3,7], outputs: [8], attributes: []},
|
||||
Node {type: "div", inputs: [6,8], outputs: [z], attributes: []},
|
||||
Node {type: "add", inputs: [x,z,1], outputs: [10], attributes: []}
|
||||
]
|
||||
}
|
||||
opset_import: [OperatorSetIdProto { domain: }],
|
||||
|
@ -1,7 +1,7 @@
|
||||
graph(%a : Dynamic) {
|
||||
%2 : int = prim::Constant[value=2]()
|
||||
%1 : string = prim::Constant[value="a\n\tb\n"]()
|
||||
%3 : string = prim::Constant[value="aa"]()
|
||||
%1 : string = prim::Constant[value="a\n\tb\n"]()
|
||||
%2 : int = prim::Constant[value=2]()
|
||||
= prim::Print(%a, %1, %2, %3)
|
||||
return (%a);
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
graph(%x : Dynamic) {
|
||||
%1 : int = prim::Constant[value=4]()
|
||||
%2 : int[] = prim::ListConstruct(%1)
|
||||
%3 : bool = prim::Constant[value=0]()
|
||||
%4 : Dynamic = aten::sum(%x, %2, %3)
|
||||
%1 : bool = prim::Constant[value=0]()
|
||||
%2 : int = prim::Constant[value=4]()
|
||||
%3 : int[] = prim::ListConstruct(%2)
|
||||
%4 : Dynamic = aten::sum(%x, %3, %1)
|
||||
return (%4);
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
graph(%x : Double(*, *, *, *, *)) {
|
||||
%1 : int = prim::Constant[value=4]()
|
||||
%2 : int[] = prim::ListConstruct(%1)
|
||||
%3 : bool = prim::Constant[value=0]()
|
||||
%4 : Dynamic = aten::sum(%x, %2, %3)
|
||||
%1 : bool = prim::Constant[value=0]()
|
||||
%2 : int = prim::Constant[value=4]()
|
||||
%3 : int[] = prim::ListConstruct(%2)
|
||||
%4 : Dynamic = aten::sum(%x, %3, %1)
|
||||
return (%4);
|
||||
}
|
||||
|
@ -1,8 +1,8 @@
|
||||
graph(%x : Float(*, *)
|
||||
%z : Float()) {
|
||||
%2 : int = prim::TensorToNum(%z)
|
||||
%3 : int = prim::Constant[value=1]()
|
||||
%y : Float(*, *) = aten::add(%x, %2, %3)
|
||||
%2 : int = prim::Constant[value=1]()
|
||||
%3 : int = prim::TensorToNum(%z)
|
||||
%y : Float(*, *) = aten::add(%x, %3, %2)
|
||||
%5 : Float(*, *) = aten::mul(%x, %y)
|
||||
return (%5);
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
graph() {
|
||||
%0 : int = prim::Constant[value=1]()
|
||||
%1 : float = prim::Constant[value=5]()
|
||||
%b : int = prim::FloatToInt(%1)
|
||||
%3 : int = aten::add(%b, %0)
|
||||
%0 : float = prim::Constant[value=5]()
|
||||
%1 : int = prim::Constant[value=1]()
|
||||
%b : int = prim::FloatToInt(%0)
|
||||
%3 : int = aten::add(%b, %1)
|
||||
return (%3);
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
graph() {
|
||||
%0 : float = prim::Constant[value=1]()
|
||||
%1 : int = prim::Constant[value=2]()
|
||||
%b : float = prim::IntToFloat(%1)
|
||||
%3 : float = aten::add(%b, %0)
|
||||
%0 : int = prim::Constant[value=2]()
|
||||
%1 : float = prim::Constant[value=1]()
|
||||
%b : float = prim::IntToFloat(%0)
|
||||
%3 : float = aten::add(%b, %1)
|
||||
return (%3);
|
||||
}
|
||||
|
@ -3075,6 +3075,33 @@ a")
|
||||
y2 = torch.sum(x, dim=0)
|
||||
self.assertEqual(y, y2)
|
||||
|
||||
def test_constant_pooling(self):
|
||||
def func(cond):
|
||||
a = 1
|
||||
b = 4
|
||||
c = 0
|
||||
d = "abc"
|
||||
e = "bcd"
|
||||
f = "abc"
|
||||
x = torch.ones([2])
|
||||
y = x * 4
|
||||
z = torch.ones([2])
|
||||
if bool(cond):
|
||||
c = b - a
|
||||
else:
|
||||
y = torch.rand(0)
|
||||
if bool(cond):
|
||||
y = torch.rand(1)
|
||||
print(d, e, f, x, y, z)
|
||||
b = b - a
|
||||
return a, b, c, x
|
||||
|
||||
self.checkScript(func, torch.tensor([1]))
|
||||
graph = torch.jit.script(func).graph
|
||||
self.run_pass('constant_propagation', graph)
|
||||
self.run_pass('constant_pooling', graph)
|
||||
self.assertExpectedGraph(graph)
|
||||
|
||||
def test_literal(self):
|
||||
def func1(a, b):
|
||||
c = a, b
|
||||
|
@ -146,6 +146,7 @@ set(TORCH_SRCS
|
||||
${TORCH_SRC_DIR}/csrc/jit/import.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/interpreter.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/constants.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/node_hashing.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/ir.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/operator.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/operator.cpp
|
||||
@ -153,6 +154,7 @@ set(TORCH_SRCS
|
||||
${TORCH_SRC_DIR}/csrc/jit/passes/batch_mm.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/passes/canonicalize.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/passes/constant_propagation.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/passes/constant_pooling.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/passes/common_subexpression_elimination.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/passes/create_autodiff_subgraphs.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/passes/inline_autodiff_subgraphs.cpp
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include "torch/csrc/jit/passes/annotate_effects.h"
|
||||
#include "torch/csrc/jit/passes/batch_mm.h"
|
||||
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
|
||||
#include "torch/csrc/jit/passes/constant_pooling.h"
|
||||
#include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
|
||||
#include "torch/csrc/jit/passes/dead_code_elimination.h"
|
||||
#include "torch/csrc/jit/passes/erase_number_types.h"
|
||||
@ -444,6 +445,7 @@ private:
|
||||
void runOptimization(std::shared_ptr<Graph>& graph, const ArgumentSpec& spec) {
|
||||
EliminateDeadCode(graph);
|
||||
EliminateCommonSubexpression(graph);
|
||||
ConstantPooling(graph);
|
||||
UnrollLoops(graph);
|
||||
PeepholeOptimize(graph);
|
||||
CheckInplace(graph);
|
||||
|
@ -15,6 +15,7 @@
|
||||
#include "torch/csrc/jit/passes/erase_number_types.h"
|
||||
#include "torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h"
|
||||
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
|
||||
#include "torch/csrc/jit/passes/constant_pooling.h"
|
||||
#include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
|
||||
#include "torch/csrc/jit/passes/peephole.h"
|
||||
#include "torch/csrc/jit/passes/canonicalize.h"
|
||||
@ -88,6 +89,7 @@ void initJITBindings(PyObject *module) {
|
||||
.def("_jit_pass_cse", [](std::shared_ptr<Graph>& g) {
|
||||
return EliminateCommonSubexpression(g); // overload resolution
|
||||
})
|
||||
.def("_jit_pass_constant_pooling", ConstantPooling)
|
||||
.def("_jit_pass_peephole", PeepholeOptimize)
|
||||
.def("_jit_pass_canonicalize", [](const std::shared_ptr<Graph>& g) {
|
||||
return Canonicalize(g);
|
||||
|
113
torch/csrc/jit/node_hashing.cpp
Normal file
113
torch/csrc/jit/node_hashing.cpp
Normal file
@ -0,0 +1,113 @@
|
||||
#include "torch/csrc/jit/ir.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "torch/csrc/jit/assertions.h"
|
||||
#include "torch/csrc/jit/interned_strings.h"
|
||||
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
|
||||
#include "torch/csrc/jit/node_hashing.h"
|
||||
#include "torch/csrc/utils/functional.h"
|
||||
#include "torch/csrc/utils/hash.h"
|
||||
|
||||
namespace torch { namespace jit {
|
||||
|
||||
namespace {
|
||||
|
||||
bool tensorEqual(const at::Tensor& lhs, const at::Tensor& rhs) {
|
||||
return &lhs.type() == &rhs.type() && lhs.equal(rhs);
|
||||
}
|
||||
|
||||
bool tensorListEqual(const std::vector<at::Tensor>& lhs, const std::vector<at::Tensor>& rhs) {
|
||||
if (lhs.size() != rhs.size()) return false;
|
||||
return std::equal(lhs.begin(), lhs.end(), rhs.begin(), tensorEqual);
|
||||
}
|
||||
|
||||
|
||||
// Check whether two nodes have the same attributes in CSE.
|
||||
// This function may be too conservative for general use.
|
||||
// Do NOT support g/gs attributes.
|
||||
bool attributesEqualCSE(const Node* lhs, const Node* rhs) {
|
||||
JIT_ASSERT(lhs != nullptr);
|
||||
JIT_ASSERT(rhs != nullptr);
|
||||
// One has attributes, the other does not.
|
||||
if (lhs->hasAttributes() != rhs->hasAttributes()) return false;
|
||||
// Neither has attributes.
|
||||
if (!lhs->hasAttributes() && !rhs->hasAttributes()) return true;
|
||||
|
||||
auto lnames = lhs->attributeNames();
|
||||
auto rnames = rhs->attributeNames();
|
||||
std::sort(lnames.begin(), lnames.end());
|
||||
std::sort(rnames.begin(), rnames.end());
|
||||
if (lnames != rnames) return false;
|
||||
|
||||
for (auto name : lnames) {
|
||||
if (lhs->kindOf(name) != rhs->kindOf(name)) return false;
|
||||
|
||||
#define COMPARE_ATTRIBUTEVALUE(type) \
|
||||
case AttributeKind::type: \
|
||||
{ if (lhs->type(name) != rhs->type(name)) return false; } break;
|
||||
|
||||
switch(lhs->kindOf(name)) {
|
||||
COMPARE_ATTRIBUTEVALUE(f)
|
||||
COMPARE_ATTRIBUTEVALUE(fs)
|
||||
COMPARE_ATTRIBUTEVALUE(i)
|
||||
COMPARE_ATTRIBUTEVALUE(is)
|
||||
COMPARE_ATTRIBUTEVALUE(s)
|
||||
COMPARE_ATTRIBUTEVALUE(ss)
|
||||
case AttributeKind::t: {
|
||||
if (!tensorEqual(lhs->t(name), rhs->t(name))) return false;
|
||||
break;
|
||||
}
|
||||
case AttributeKind::ts: {
|
||||
if (!tensorListEqual(lhs->ts(name), rhs->ts(name))) return false;
|
||||
break;
|
||||
}
|
||||
case AttributeKind::g:
|
||||
case AttributeKind::gs:
|
||||
return false;
|
||||
}
|
||||
|
||||
#undef COMPARE_ATTRIBUTEVALUE
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
|
||||
size_t HashNode::operator()(const Node* k) const {
|
||||
JIT_ASSERT(k != nullptr);
|
||||
return get_hash(k->kind(),
|
||||
fmap(k->outputs(), [](const Value *v) { return v->type()->kind(); }),
|
||||
fmap(k->inputs(), [](const Value *v) { return v->unique(); }));
|
||||
};
|
||||
|
||||
bool EqualNode::operator()(const Node* lhs, const Node* rhs) const {
|
||||
if (lhs == nullptr && rhs == nullptr) return true;
|
||||
if (lhs == nullptr || rhs == nullptr) return false;
|
||||
|
||||
if (lhs->kind() != rhs->kind()) return false;
|
||||
|
||||
// Check whether the output types are the same.
|
||||
auto lhs_outputs = lhs->outputs();
|
||||
auto rhs_outputs = rhs->outputs();
|
||||
if (lhs_outputs.size() != rhs_outputs.size()) return false;
|
||||
for (size_t i = 0; i < lhs_outputs.size(); ++i) {
|
||||
if (*lhs_outputs[i]->type() != *rhs_outputs[i]->type())
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check whether the inputs are the same.
|
||||
auto lhs_inputs = lhs->inputs();
|
||||
auto rhs_inputs = rhs->inputs();
|
||||
if (lhs_inputs.size() != rhs_inputs.size()) return false;
|
||||
if (!std::equal(lhs_inputs.begin(), lhs_inputs.end(), rhs_inputs.begin())) return false;
|
||||
|
||||
if (!attributesEqualCSE(lhs, rhs)) return false;
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
}}
|
15
torch/csrc/jit/node_hashing.h
Normal file
15
torch/csrc/jit/node_hashing.h
Normal file
@ -0,0 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
#include "torch/csrc/jit/ir.h"
|
||||
|
||||
namespace torch { namespace jit {
|
||||
|
||||
struct HashNode {
|
||||
size_t operator()(const Node* k) const;
|
||||
};
|
||||
|
||||
struct EqualNode {
|
||||
bool operator()(const Node* lhs, const Node* rhs) const;
|
||||
};
|
||||
|
||||
}}
|
@ -6,118 +6,17 @@
|
||||
#include "torch/csrc/jit/assertions.h"
|
||||
#include "torch/csrc/jit/interned_strings.h"
|
||||
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
|
||||
#include "torch/csrc/jit/node_hashing.h"
|
||||
#include "torch/csrc/utils/functional.h"
|
||||
#include "torch/csrc/utils/hash.h"
|
||||
|
||||
namespace torch { namespace jit {
|
||||
|
||||
namespace {
|
||||
|
||||
bool tensorEqual(const at::Tensor& lhs, const at::Tensor& rhs) {
|
||||
return &lhs.type() == &rhs.type() && lhs.equal(rhs);
|
||||
}
|
||||
|
||||
bool tensorListEqual(const std::vector<at::Tensor>& lhs, const std::vector<at::Tensor>& rhs) {
|
||||
if (lhs.size() != rhs.size()) return false;
|
||||
return std::equal(lhs.begin(), lhs.end(), rhs.begin(), tensorEqual);
|
||||
}
|
||||
|
||||
|
||||
// Check whether two nodes have the same attributes in CSE.
|
||||
// This function may be too conservative for general use.
|
||||
// Do NOT support t/ts/g/gs attributes.
|
||||
// If t/ts are supported, CONSTANT node comparison may need to consider device.
|
||||
bool attributesEqualCSE(const Node* lhs, const Node* rhs) {
|
||||
JIT_ASSERT(lhs != nullptr);
|
||||
JIT_ASSERT(rhs != nullptr);
|
||||
// One has attributes, the other does not.
|
||||
if (lhs->hasAttributes() != rhs->hasAttributes()) return false;
|
||||
// Neither has attributes.
|
||||
if (!lhs->hasAttributes() && !rhs->hasAttributes()) return true;
|
||||
|
||||
auto lnames = lhs->attributeNames();
|
||||
auto rnames = rhs->attributeNames();
|
||||
std::sort(lnames.begin(), lnames.end());
|
||||
std::sort(rnames.begin(), rnames.end());
|
||||
if (lnames != rnames) return false;
|
||||
|
||||
for (auto name : lnames) {
|
||||
if (lhs->kindOf(name) != rhs->kindOf(name)) return false;
|
||||
|
||||
#define COMPARE_ATTRIBUTEVALUE(type) \
|
||||
case AttributeKind::type: \
|
||||
{ if (lhs->type(name) != rhs->type(name)) return false; } break;
|
||||
|
||||
switch(lhs->kindOf(name)) {
|
||||
COMPARE_ATTRIBUTEVALUE(f)
|
||||
COMPARE_ATTRIBUTEVALUE(fs)
|
||||
COMPARE_ATTRIBUTEVALUE(i)
|
||||
COMPARE_ATTRIBUTEVALUE(is)
|
||||
COMPARE_ATTRIBUTEVALUE(s)
|
||||
COMPARE_ATTRIBUTEVALUE(ss)
|
||||
case AttributeKind::t: {
|
||||
if (!tensorEqual(lhs->t(name), rhs->t(name))) return false;
|
||||
break;
|
||||
}
|
||||
case AttributeKind::ts: {
|
||||
if (!tensorListEqual(lhs->ts(name), rhs->ts(name))) return false;
|
||||
break;
|
||||
}
|
||||
case AttributeKind::g:
|
||||
case AttributeKind::gs:
|
||||
return false;
|
||||
}
|
||||
|
||||
#undef COMPARE_ATTRIBUTEVALUE
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
struct HashNodeCSE {
|
||||
size_t operator()(const Node* k) const {
|
||||
JIT_ASSERT(k != nullptr);
|
||||
return get_hash(k->kind(),
|
||||
fmap(k->outputs(), [](const Value *v) { return v->type()->kind(); }),
|
||||
fmap(k->inputs(), [](const Value *v) { return v->unique(); }));
|
||||
}
|
||||
};
|
||||
|
||||
struct EqualNodeCSE {
|
||||
bool operator()(const Node* lhs, const Node* rhs) const {
|
||||
if (lhs == nullptr && rhs == nullptr) return true;
|
||||
if (lhs == nullptr || rhs == nullptr) return false;
|
||||
|
||||
if (lhs->kind() != rhs->kind()) return false;
|
||||
|
||||
// Check whether the output types are the same.
|
||||
auto lhs_outputs = lhs->outputs();
|
||||
auto rhs_outputs = rhs->outputs();
|
||||
if (lhs_outputs.size() != rhs_outputs.size()) return false;
|
||||
for (size_t i = 0; i < lhs_outputs.size(); ++i) {
|
||||
if (*lhs_outputs[i]->type() != *rhs_outputs[i]->type())
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check whether the inputs are the same.
|
||||
auto lhs_inputs = lhs->inputs();
|
||||
auto rhs_inputs = rhs->inputs();
|
||||
if (lhs_inputs.size() != rhs_inputs.size()) return false;
|
||||
if (!std::equal(lhs_inputs.begin(), lhs_inputs.end(), rhs_inputs.begin())) return false;
|
||||
|
||||
if (!attributesEqualCSE(lhs, rhs)) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// The function implements common subexpression elimination.
|
||||
// Since the nodes are visited in topological order, one pass is enough.
|
||||
void EliminateCommonSubexpression(Block * block,
|
||||
std::function<Node*(Node*)> parent_lookup_fn) {
|
||||
std::unordered_set<Node*, HashNodeCSE, EqualNodeCSE> subexprs;
|
||||
std::unordered_set<Node*, HashNode, EqualNode> subexprs;
|
||||
for (auto it = block->nodes().begin(); it != block->nodes().end(); ++ it) {
|
||||
auto node = *it;
|
||||
if (node->kind() == prim::PythonOp
|
||||
|
53
torch/csrc/jit/passes/constant_pooling.cpp
Normal file
53
torch/csrc/jit/passes/constant_pooling.cpp
Normal file
@ -0,0 +1,53 @@
|
||||
#include "torch/csrc/jit/ir.h"
|
||||
#include <unordered_set>
|
||||
#include "torch/csrc/jit/interned_strings.h"
|
||||
#include "torch/csrc/jit/passes/constant_pooling.h"
|
||||
#include "torch/csrc/jit/node_hashing.h"
|
||||
|
||||
namespace torch { namespace jit {
|
||||
|
||||
namespace {
|
||||
|
||||
//Very similar to the common subexpression elimination pass
|
||||
//Move all constants to the beginning of the graph, and deduplicate
|
||||
void ConstantPooling(Block * block, std::unordered_set<Node*, HashNode, EqualNode>& constants) {
|
||||
for (auto it = block->nodes().begin(); it != block->nodes().end();) {
|
||||
auto node = *it;
|
||||
// node may be moved to a different block so advance iterator now
|
||||
++it;
|
||||
if (!node->blocks().empty()) {
|
||||
// Traverse sub-blocks.
|
||||
for (auto block : node->blocks()) {
|
||||
ConstantPooling(block, constants);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (node->kind() != prim::Constant) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto first_node = node->owningGraph()->block()->nodes().front();
|
||||
if (node != first_node)
|
||||
node->moveBefore(first_node);
|
||||
|
||||
// Check whether the same constant already exists.
|
||||
auto subit = constants.insert(node);
|
||||
if (!subit.second) {
|
||||
// constant exists, replace the uses of node, and destroy it.
|
||||
auto existing = *subit.first;
|
||||
node->replaceAllUsesWith(existing);
|
||||
node->destroy();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
|
||||
void ConstantPooling(const std::shared_ptr<Graph>& graph) {
|
||||
std::unordered_set<Node*, HashNode, EqualNode> constants;
|
||||
ConstantPooling(graph->block(), constants);
|
||||
}
|
||||
|
||||
}}
|
9
torch/csrc/jit/passes/constant_pooling.h
Normal file
9
torch/csrc/jit/passes/constant_pooling.h
Normal file
@ -0,0 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include "torch/csrc/jit/ir.h"
|
||||
|
||||
namespace torch { namespace jit {
|
||||
|
||||
TORCH_API void ConstantPooling(const std::shared_ptr<Graph>& graph);
|
||||
|
||||
}}
|
@ -1,6 +1,7 @@
|
||||
#include "torch/csrc/jit/script/compiler.h"
|
||||
#include "torch/csrc/jit/passes/lower_tuples.h"
|
||||
#include "torch/csrc/jit/passes/annotate_effects.h"
|
||||
#include "torch/csrc/jit/passes/constant_pooling.h"
|
||||
#include "torch/csrc/jit/operator.h"
|
||||
#include "torch/csrc/jit/interpreter.h"
|
||||
#include "torch/csrc/jit/ir.h"
|
||||
@ -903,6 +904,7 @@ struct to_ir {
|
||||
AnnotateEffects(graph);
|
||||
// remove any uses of tuples that we inserted that are not needed
|
||||
LowerSimpleTuples(graph);
|
||||
ConstantPooling(graph);
|
||||
}
|
||||
|
||||
private:
|
||||
|
Reference in New Issue
Block a user