mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[quant] Fix fusion pattern for add_relu (#39367)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/39367 We shouldn't match `%alpha` argument since it could be used by multiple functions Test Plan: Imported from OSS Differential Revision: D21829295 fbshipit-source-id: 6daa320a4b56df4e142b8e02e04a3ecb36284d1b
This commit is contained in:
committed by
Facebook GitHub Bot
parent
3001facd7a
commit
625f4e39a7
@ -241,8 +241,7 @@ graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %pad
|
||||
return (%r_quant) )";
|
||||
|
||||
std::string add_relu = R"(
|
||||
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
|
||||
%alpha = prim::Constant[value=1]()
|
||||
graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
|
||||
%a_dequant = aten::dequantize(%a_quant)
|
||||
%b_dequant = aten::dequantize(%b_quant)
|
||||
%r_add = aten::add(%a_dequant, %b_dequant, %alpha)
|
||||
@ -251,8 +250,7 @@ graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
|
||||
return (%r) )";
|
||||
|
||||
std::string add_inplace_relu = R"(
|
||||
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
|
||||
%alpha = prim::Constant[value=1]()
|
||||
graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
|
||||
%a_dequant = aten::dequantize(%a_quant)
|
||||
%b_dequant = aten::dequantize(%b_quant)
|
||||
%r_add = aten::add(%a_dequant, %b_dequant, %alpha)
|
||||
@ -261,8 +259,7 @@ graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
|
||||
return (%r) )";
|
||||
|
||||
std::string inplace_add_relu = R"(
|
||||
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
|
||||
%alpha = prim::Constant[value=1]()
|
||||
graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
|
||||
%a_dequant = aten::dequantize(%a_quant)
|
||||
%b_dequant = aten::dequantize(%b_quant)
|
||||
%r_add = aten::add_(%a_dequant, %b_dequant, %alpha)
|
||||
@ -271,8 +268,7 @@ graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
|
||||
return (%r) )";
|
||||
|
||||
std::string inplace_add_inplace_relu = R"(
|
||||
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
|
||||
%alpha = prim::Constant[value=1]()
|
||||
graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
|
||||
%a_dequant = aten::dequantize(%a_quant)
|
||||
%b_dequant = aten::dequantize(%b_quant)
|
||||
%r_add = aten::add_(%a_dequant, %b_dequant, %alpha)
|
||||
@ -281,7 +277,7 @@ graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
|
||||
return (%r) )";
|
||||
|
||||
std::string quantized_add_relu = R"(
|
||||
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
|
||||
graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
|
||||
%r = quantized::add_relu(%a_quant, %b_quant, %scale, %zero_point)
|
||||
return (%r) )";
|
||||
|
||||
|
Reference in New Issue
Block a user