[quant] Fix fusion pattern for add_relu (#39367)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39367

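The add_relu fusion patterns shouldn't match the `%alpha` argument as a `prim::Constant` defined inside the pattern, since that constant could be used by multiple functions; `%alpha` is now taken as a graph input of each pattern instead.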

Test Plan: Imported from OSS

Differential Revision: D21829295

fbshipit-source-id: 6daa320a4b56df4e142b8e02e04a3ecb36284d1b
This commit is contained in:
Jerry Zhang
2020-06-01 20:13:09 -07:00
committed by Facebook GitHub Bot
parent 3001facd7a
commit 625f4e39a7
2 changed files with 38 additions and 13 deletions

View File

@ -241,8 +241,7 @@ graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %pad
return (%r_quant) )";
std::string add_relu = R"(
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
%alpha = prim::Constant[value=1]()
graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
%a_dequant = aten::dequantize(%a_quant)
%b_dequant = aten::dequantize(%b_quant)
%r_add = aten::add(%a_dequant, %b_dequant, %alpha)
@ -251,8 +250,7 @@ graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
return (%r) )";
std::string add_inplace_relu = R"(
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
%alpha = prim::Constant[value=1]()
graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
%a_dequant = aten::dequantize(%a_quant)
%b_dequant = aten::dequantize(%b_quant)
%r_add = aten::add(%a_dequant, %b_dequant, %alpha)
@ -261,8 +259,7 @@ graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
return (%r) )";
std::string inplace_add_relu = R"(
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
%alpha = prim::Constant[value=1]()
graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
%a_dequant = aten::dequantize(%a_quant)
%b_dequant = aten::dequantize(%b_quant)
%r_add = aten::add_(%a_dequant, %b_dequant, %alpha)
@ -271,8 +268,7 @@ graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
return (%r) )";
std::string inplace_add_inplace_relu = R"(
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
%alpha = prim::Constant[value=1]()
graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
%a_dequant = aten::dequantize(%a_quant)
%b_dequant = aten::dequantize(%b_quant)
%r_add = aten::add_(%a_dequant, %b_dequant, %alpha)
@ -281,7 +277,7 @@ graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
return (%r) )";
std::string quantized_add_relu = R"(
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
%r = quantized::add_relu(%a_quant, %b_quant, %scale, %zero_point)
return (%r) )";