[quant][graphmode] Different rule for add/add_/mul/mul_ (#38667)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/38667

Test Plan: Imported from OSS

Differential Revision: D21633555

fbshipit-source-id: 03b0298e83bf4dbda41b048c0edc7bb92cd4e1df
This commit is contained in:
Jerry Zhang
2020-05-20 19:42:02 -07:00
committed by Facebook GitHub Bot
parent 57d6e19d6f
commit a8d8fc5532
9 changed files with 819 additions and 367 deletions

View File

@ -245,12 +245,32 @@ graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
%alpha = prim::Constant[value=1]()
%a_dequant = aten::dequantize(%a_quant)
%b_dequant = aten::dequantize(%b_quant)
%r_add = aten::add_(%a_dequant, %b_dequant, %alpha)
%r_add = aten::add(%a_dequant, %b_dequant, %alpha)
%r_relu = aten::relu(%r_add)
%r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
return (%r) )";
// Graph pattern: dequantize %a_quant and %b_quant, combine them with
// out-of-place aten::add (with %alpha fixed to the constant 1), apply
// in-place aten::relu_, then re-quantize via aten::quantize_per_tensor.
// Paired with quantized_add_relu and add_filter in the pattern table.
std::string add_inplace_relu = R"(
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
%alpha = prim::Constant[value=1]()
%a_dequant = aten::dequantize(%a_quant)
%b_dequant = aten::dequantize(%b_quant)
%r_add = aten::add(%a_dequant, %b_dequant, %alpha)
%r_relu = aten::relu_(%r_add)
%r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
return (%r) )";
// Graph pattern: dequantize %a_quant and %b_quant, combine them with
// in-place aten::add_ (with %alpha fixed to the constant 1), apply
// out-of-place aten::relu, then re-quantize via aten::quantize_per_tensor.
// Paired with quantized_add_relu and add_filter in the pattern table.
std::string inplace_add_relu = R"(
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
%alpha = prim::Constant[value=1]()
%a_dequant = aten::dequantize(%a_quant)
%b_dequant = aten::dequantize(%b_quant)
%r_add = aten::add_(%a_dequant, %b_dequant, %alpha)
%r_relu = aten::relu(%r_add)
%r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
return (%r) )";
std::string inplace_add_inplace_relu = R"(
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
%alpha = prim::Constant[value=1]()
%a_dequant = aten::dequantize(%a_quant)
@ -350,7 +370,7 @@ graph(%a_quant, %b_scalar, %alpha):
};
// quantized::add_scalar_out
std::string add_scalar_out = R"(
std::string inplace_add_scalar = R"(
graph(%a_quant, %b_scalar, %alpha):
%a_dequant = aten::dequantize(%a_quant)
%r = aten::add_(%a_dequant, %b_scalar, %alpha)
@ -369,19 +389,33 @@ graph(%a_quant, %b_scalar, %alpha):
%r = aten::relu(%r_add)
return (%r) )";
// Graph pattern: dequantize %a_quant, add %b_scalar with out-of-place
// aten::add (using the %alpha graph input), then apply in-place aten::relu_.
// The pattern has no trailing quantize op; it returns the relu_ result.
// Paired with quantized_add_scalar_relu and add_scalar_filter in the
// pattern table.
std::string add_scalar_inplace_relu = R"(
graph(%a_quant, %b_scalar, %alpha):
%a_dequant = aten::dequantize(%a_quant)
%r_add = aten::add(%a_dequant, %b_scalar, %alpha)
%r = aten::relu_(%r_add)
return (%r) )";
// Graph consisting of a single quantized::add_scalar_relu call on
// %a_quant and %b_scalar. %alpha is kept as a graph input so the signature
// lines up with the add_scalar*relu patterns, but it is not passed to the op.
// Used as the third entry of the "quantized::add_scalar_relu" table rows.
std::string quantized_add_scalar_relu = R"(
graph(%a_quant, %b_scalar, %alpha):
%r = quantized::add_scalar_relu(%a_quant, %b_scalar)
return (%r) )";
// quantized::add_scalar_relu_out
std::string add_scalar_relu_out = R"(
std::string inplace_add_scalar_relu = R"(
graph(%a_quant, %b_scalar, %alpha):
%a_dequant = aten::dequantize(%a_quant)
%r_add = aten::add_(%a_dequant, %b_scalar, %alpha)
%r = aten::relu(%r_add)
return (%r) )";
// Graph pattern: dequantize %a_quant, add %b_scalar with in-place
// aten::add_ (using the %alpha graph input), then apply in-place aten::relu_.
// The pattern has no trailing quantize op; it returns the relu_ result.
// Paired with quantized_add_scalar_relu_out and add_scalar_filter in the
// pattern table.
std::string inplace_add_scalar_inplace_relu = R"(
graph(%a_quant, %b_scalar, %alpha):
%a_dequant = aten::dequantize(%a_quant)
%r_add = aten::add_(%a_dequant, %b_scalar, %alpha)
%r = aten::relu_(%r_add)
return (%r) )";
std::string quantized_add_scalar_relu_out = R"(
graph(%a_quant, %b_scalar, %alpha):
%r = quantized::add_scalar_relu_out(%a_quant, %b_scalar, %a_quant)
@ -450,7 +484,7 @@ graph(%a_quant, %b_scalar):
%r = aten::mul(%a_dequant, %b_scalar)
return (%r) )";
std::string mul_scalar_out = R"(
std::string inplace_mul_scalar = R"(
graph(%a_quant, %b_scalar):
%a_dequant = aten::dequantize(%a_quant)
%r = aten::mul_(%a_dequant, %b_scalar)
@ -485,15 +519,6 @@ graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
%r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
return (%r) )";
// Graph pattern: dequantize %a_quant and %b_quant, multiply them with
// in-place aten::mul_, apply out-of-place aten::relu, then re-quantize via
// aten::quantize_per_tensor.
// Paired with quantized_mul_relu in the pattern table.
std::string inplace_mul_relu = R"(
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
%a_dequant = aten::dequantize(%a_quant)
%b_dequant = aten::dequantize(%b_quant)
%r_mul = aten::mul_(%a_dequant, %b_dequant)
%r_relu = aten::relu(%r_mul)
%r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
return (%r) )";
std::string mul_inplace_relu = R"(
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
%a_dequant = aten::dequantize(%a_quant)
@ -503,6 +528,15 @@ graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
%r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
return (%r) )";
// Graph pattern: dequantize %a_quant and %b_quant, multiply them with
// in-place aten::mul_, apply out-of-place aten::relu, then re-quantize via
// aten::quantize_per_tensor.
// Paired with quantized_mul_relu in the pattern table.
std::string inplace_mul_relu = R"(
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
%a_dequant = aten::dequantize(%a_quant)
%b_dequant = aten::dequantize(%b_quant)
%r_mul = aten::mul_(%a_dequant, %b_dequant)
%r_relu = aten::relu(%r_mul)
%r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
return (%r) )";
std::string inplace_mul_inplace_relu = R"(
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
%a_dequant = aten::dequantize(%a_quant)
@ -525,19 +559,33 @@ graph(%a_quant, %b_scalar):
%r = aten::relu(%r_mul)
return (%r) )";
// Graph pattern: dequantize %a_quant, multiply by %b_scalar with
// out-of-place aten::mul, then apply in-place aten::relu_.
// The pattern has no trailing quantize op; it returns the relu_ result.
// Paired with quantized_mul_scalar_relu and mul_scalar_filter in the
// pattern table.
std::string mul_scalar_inplace_relu = R"(
graph(%a_quant, %b_scalar):
%a_dequant = aten::dequantize(%a_quant)
%r_mul = aten::mul(%a_dequant, %b_scalar)
%r = aten::relu_(%r_mul)
return (%r) )";
// Graph consisting of a single quantized::mul_scalar_relu call on
// %a_quant and %b_scalar.
// Used as the third entry of the "quantized::mul_scalar_relu" table rows.
std::string quantized_mul_scalar_relu = R"(
graph(%a_quant, %b_scalar):
%r = quantized::mul_scalar_relu(%a_quant, %b_scalar)
return (%r) )";
// quantized::mul_scalar_relu_out
std::string mul_scalar_relu_out = R"(
std::string inplace_mul_scalar_relu = R"(
graph(%a_quant, %b_scalar):
%a_dequant = aten::dequantize(%a_quant)
%r_mul = aten::mul_(%a_dequant, %b_scalar)
%r = aten::relu(%r_mul)
return (%r) )";
// Graph pattern: dequantize %a_quant, multiply by %b_scalar with in-place
// aten::mul_, then apply in-place aten::relu_.
// The pattern has no trailing quantize op; it returns the relu_ result.
// Paired with quantized_mul_scalar_relu_out and mul_scalar_filter in the
// pattern table.
std::string inplace_mul_scalar_inplace_relu = R"(
graph(%a_quant, %b_scalar):
%a_dequant = aten::dequantize(%a_quant)
%r_mul = aten::mul_(%a_dequant, %b_scalar)
%r = aten::relu_(%r_mul)
return (%r) )";
std::string quantized_mul_scalar_relu_out = R"(
graph(%a_quant, %b_scalar):
%r = quantized::mul_scalar_relu_out(%a_quant, %b_scalar, %a_quant)
@ -687,15 +735,26 @@ graph(%a_quant, %normalized_shape, %weight, %bias, %eps, %cudnn_enabled, %output
{"quantized::linear", linear, quantized_linear},
{"quantized::add_relu", add_relu, quantized_add_relu, add_filter},
{"quantized::add_relu", add_inplace_relu, quantized_add_relu, add_filter},
{"quantized::add", add, quantized_add, add_filter},
{"quantized::add", inplace_add, quantized_add, add_filter},
{"quantized::add_relu", inplace_add_relu, quantized_add_relu, add_filter},
{"quantized::add_relu",
inplace_add_inplace_relu,
quantized_add_relu,
add_filter},
// note that this must come before quantized::add_scalar
{"quantized::add_scalar_relu",
add_scalar_relu,
quantized_add_scalar_relu,
add_scalar_filter},
{"quantized::add_scalar_relu",
add_scalar_inplace_relu,
quantized_add_scalar_relu,
add_scalar_filter},
{"quantized::add_scalar_relu_out",
add_scalar_relu_out,
inplace_add_scalar_relu,
quantized_add_scalar_relu_out,
add_scalar_filter},
{"quantized::add_scalar_relu_out",
inplace_add_scalar_inplace_relu,
quantized_add_scalar_relu_out,
add_scalar_filter},
{"quantized::add_scalar",
@ -703,9 +762,11 @@ graph(%a_quant, %normalized_shape, %weight, %bias, %eps, %cudnn_enabled, %output
quantized_add_scalar,
add_scalar_filter},
{"quantized::add_scalar_out",
add_scalar_out,
inplace_add_scalar,
quantized_add_scalar_out,
add_scalar_filter},
{"quantized::add", add, quantized_add, add_filter},
{"quantized::add", inplace_add, quantized_add, add_filter},
{"quantized::cat", cat, quantized_cat},
{"quantized::batch_norm2d", batch_norm2d, quantized_batch_norm2d},
{"quantized::batch_norm2d_relu",
@ -714,14 +775,20 @@ graph(%a_quant, %normalized_shape, %weight, %bias, %eps, %cudnn_enabled, %output
{"quantized::batch_norm2d_relu",
batch_norm2d_inplace_relu,
quantized_batch_norm2d_relu},
{"quantized::mul", mul, quantized_mul},
{"quantized::mul", inplace_mul, quantized_mul},
{"quantized::mul_scalar_relu",
mul_scalar_relu,
quantized_mul_scalar_relu,
mul_scalar_filter},
{"quantized::mul_scalar_relu",
mul_scalar_inplace_relu,
quantized_mul_scalar_relu,
mul_scalar_filter},
{"quantized::mul_scalar_relu_out",
mul_scalar_relu_out,
inplace_mul_scalar_relu,
quantized_mul_scalar_relu_out,
mul_scalar_filter},
{"quantized::mul_scalar_relu_out",
inplace_mul_scalar_inplace_relu,
quantized_mul_scalar_relu_out,
mul_scalar_filter},
{"quantized::mul_scalar",
@ -729,13 +796,15 @@ graph(%a_quant, %normalized_shape, %weight, %bias, %eps, %cudnn_enabled, %output
quantized_mul_scalar,
mul_scalar_filter},
{"quantized::mul_scalar",
mul_scalar_out,
inplace_mul_scalar,
quantized_mul_scalar_out,
mul_scalar_filter},
{"quantized::mul_relu", mul_relu, quantized_mul_relu},
{"quantized::mul_relu", mul_inplace_relu, quantized_mul_relu},
{"quantized::mul_relu", inplace_mul_relu, quantized_mul_relu},
{"quantized::mul_relu", inplace_mul_inplace_relu, quantized_mul_relu},
{"quantized::mul", mul, quantized_mul},
{"quantized::mul", inplace_mul, quantized_mul},
{"quantized::hardswish", hardswish, quantized_hardswish},
{"quantized::layer_norm", layer_norm, quantized_layer_norm},
avg_pool1d,