mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[quant][graphmode] Different rule for add/add_/mul/mul_ (#38667)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/38667 Test Plan: Imported from OSS Differential Revision: D21633555 fbshipit-source-id: 03b0298e83bf4dbda41b048c0edc7bb92cd4e1df
This commit is contained in:
committed by
Facebook GitHub Bot
parent
57d6e19d6f
commit
a8d8fc5532
@ -245,12 +245,32 @@ graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
|
||||
%alpha = prim::Constant[value=1]()
|
||||
%a_dequant = aten::dequantize(%a_quant)
|
||||
%b_dequant = aten::dequantize(%b_quant)
|
||||
%r_add = aten::add_(%a_dequant, %b_dequant, %alpha)
|
||||
%r_add = aten::add(%a_dequant, %b_dequant, %alpha)
|
||||
%r_relu = aten::relu(%r_add)
|
||||
%r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
|
||||
return (%r) )";
|
||||
|
||||
std::string add_inplace_relu = R"(
|
||||
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
|
||||
%alpha = prim::Constant[value=1]()
|
||||
%a_dequant = aten::dequantize(%a_quant)
|
||||
%b_dequant = aten::dequantize(%b_quant)
|
||||
%r_add = aten::add(%a_dequant, %b_dequant, %alpha)
|
||||
%r_relu = aten::relu_(%r_add)
|
||||
%r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
|
||||
return (%r) )";
|
||||
|
||||
std::string inplace_add_relu = R"(
|
||||
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
|
||||
%alpha = prim::Constant[value=1]()
|
||||
%a_dequant = aten::dequantize(%a_quant)
|
||||
%b_dequant = aten::dequantize(%b_quant)
|
||||
%r_add = aten::add_(%a_dequant, %b_dequant, %alpha)
|
||||
%r_relu = aten::relu(%r_add)
|
||||
%r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
|
||||
return (%r) )";
|
||||
|
||||
std::string inplace_add_inplace_relu = R"(
|
||||
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
|
||||
%alpha = prim::Constant[value=1]()
|
||||
%a_dequant = aten::dequantize(%a_quant)
|
||||
@ -350,7 +370,7 @@ graph(%a_quant, %b_scalar, %alpha):
|
||||
};
|
||||
|
||||
// quantized::add_scalar_out
|
||||
std::string add_scalar_out = R"(
|
||||
std::string inplace_add_scalar = R"(
|
||||
graph(%a_quant, %b_scalar, %alpha):
|
||||
%a_dequant = aten::dequantize(%a_quant)
|
||||
%r = aten::add_(%a_dequant, %b_scalar, %alpha)
|
||||
@ -369,19 +389,33 @@ graph(%a_quant, %b_scalar, %alpha):
|
||||
%r = aten::relu(%r_add)
|
||||
return (%r) )";
|
||||
|
||||
std::string add_scalar_inplace_relu = R"(
|
||||
graph(%a_quant, %b_scalar, %alpha):
|
||||
%a_dequant = aten::dequantize(%a_quant)
|
||||
%r_add = aten::add(%a_dequant, %b_scalar, %alpha)
|
||||
%r = aten::relu_(%r_add)
|
||||
return (%r) )";
|
||||
|
||||
std::string quantized_add_scalar_relu = R"(
|
||||
graph(%a_quant, %b_scalar, %alpha):
|
||||
%r = quantized::add_scalar_relu(%a_quant, %b_scalar)
|
||||
return (%r) )";
|
||||
|
||||
// quantized::add_scalar_relu_out
|
||||
std::string add_scalar_relu_out = R"(
|
||||
std::string inplace_add_scalar_relu = R"(
|
||||
graph(%a_quant, %b_scalar, %alpha):
|
||||
%a_dequant = aten::dequantize(%a_quant)
|
||||
%r_add = aten::add_(%a_dequant, %b_scalar, %alpha)
|
||||
%r = aten::relu(%r_add)
|
||||
return (%r) )";
|
||||
|
||||
std::string inplace_add_scalar_inplace_relu = R"(
|
||||
graph(%a_quant, %b_scalar, %alpha):
|
||||
%a_dequant = aten::dequantize(%a_quant)
|
||||
%r_add = aten::add_(%a_dequant, %b_scalar, %alpha)
|
||||
%r = aten::relu_(%r_add)
|
||||
return (%r) )";
|
||||
|
||||
std::string quantized_add_scalar_relu_out = R"(
|
||||
graph(%a_quant, %b_scalar, %alpha):
|
||||
%r = quantized::add_scalar_relu_out(%a_quant, %b_scalar, %a_quant)
|
||||
@ -450,7 +484,7 @@ graph(%a_quant, %b_scalar):
|
||||
%r = aten::mul(%a_dequant, %b_scalar)
|
||||
return (%r) )";
|
||||
|
||||
std::string mul_scalar_out = R"(
|
||||
std::string inplace_mul_scalar = R"(
|
||||
graph(%a_quant, %b_scalar):
|
||||
%a_dequant = aten::dequantize(%a_quant)
|
||||
%r = aten::mul_(%a_dequant, %b_scalar)
|
||||
@ -485,15 +519,6 @@ graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
|
||||
%r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
|
||||
return (%r) )";
|
||||
|
||||
std::string inplace_mul_relu = R"(
|
||||
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
|
||||
%a_dequant = aten::dequantize(%a_quant)
|
||||
%b_dequant = aten::dequantize(%b_quant)
|
||||
%r_mul = aten::mul_(%a_dequant, %b_dequant)
|
||||
%r_relu = aten::relu(%r_mul)
|
||||
%r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
|
||||
return (%r) )";
|
||||
|
||||
std::string mul_inplace_relu = R"(
|
||||
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
|
||||
%a_dequant = aten::dequantize(%a_quant)
|
||||
@ -503,6 +528,15 @@ graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
|
||||
%r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
|
||||
return (%r) )";
|
||||
|
||||
std::string inplace_mul_relu = R"(
|
||||
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
|
||||
%a_dequant = aten::dequantize(%a_quant)
|
||||
%b_dequant = aten::dequantize(%b_quant)
|
||||
%r_mul = aten::mul_(%a_dequant, %b_dequant)
|
||||
%r_relu = aten::relu(%r_mul)
|
||||
%r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
|
||||
return (%r) )";
|
||||
|
||||
std::string inplace_mul_inplace_relu = R"(
|
||||
graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
|
||||
%a_dequant = aten::dequantize(%a_quant)
|
||||
@ -525,19 +559,33 @@ graph(%a_quant, %b_scalar):
|
||||
%r = aten::relu(%r_mul)
|
||||
return (%r) )";
|
||||
|
||||
std::string mul_scalar_inplace_relu = R"(
|
||||
graph(%a_quant, %b_scalar):
|
||||
%a_dequant = aten::dequantize(%a_quant)
|
||||
%r_mul = aten::mul(%a_dequant, %b_scalar)
|
||||
%r = aten::relu_(%r_mul)
|
||||
return (%r) )";
|
||||
|
||||
std::string quantized_mul_scalar_relu = R"(
|
||||
graph(%a_quant, %b_scalar):
|
||||
%r = quantized::mul_scalar_relu(%a_quant, %b_scalar)
|
||||
return (%r) )";
|
||||
|
||||
// quantized::mul_scalar_relu_out
|
||||
std::string mul_scalar_relu_out = R"(
|
||||
std::string inplace_mul_scalar_relu = R"(
|
||||
graph(%a_quant, %b_scalar):
|
||||
%a_dequant = aten::dequantize(%a_quant)
|
||||
%r_mul = aten::mul_(%a_dequant, %b_scalar)
|
||||
%r = aten::relu(%r_mul)
|
||||
return (%r) )";
|
||||
|
||||
std::string inplace_mul_scalar_inplace_relu = R"(
|
||||
graph(%a_quant, %b_scalar):
|
||||
%a_dequant = aten::dequantize(%a_quant)
|
||||
%r_mul = aten::mul_(%a_dequant, %b_scalar)
|
||||
%r = aten::relu_(%r_mul)
|
||||
return (%r) )";
|
||||
|
||||
std::string quantized_mul_scalar_relu_out = R"(
|
||||
graph(%a_quant, %b_scalar):
|
||||
%r = quantized::mul_scalar_relu_out(%a_quant, %b_scalar, %a_quant)
|
||||
@ -687,15 +735,26 @@ graph(%a_quant, %normalized_shape, %weight, %bias, %eps, %cudnn_enabled, %output
|
||||
{"quantized::linear", linear, quantized_linear},
|
||||
{"quantized::add_relu", add_relu, quantized_add_relu, add_filter},
|
||||
{"quantized::add_relu", add_inplace_relu, quantized_add_relu, add_filter},
|
||||
{"quantized::add", add, quantized_add, add_filter},
|
||||
{"quantized::add", inplace_add, quantized_add, add_filter},
|
||||
{"quantized::add_relu", inplace_add_relu, quantized_add_relu, add_filter},
|
||||
{"quantized::add_relu",
|
||||
inplace_add_inplace_relu,
|
||||
quantized_add_relu,
|
||||
add_filter},
|
||||
// note that this must come before quantized::add_scalar
|
||||
{"quantized::add_scalar_relu",
|
||||
add_scalar_relu,
|
||||
quantized_add_scalar_relu,
|
||||
add_scalar_filter},
|
||||
{"quantized::add_scalar_relu",
|
||||
add_scalar_inplace_relu,
|
||||
quantized_add_scalar_relu,
|
||||
add_scalar_filter},
|
||||
{"quantized::add_scalar_relu_out",
|
||||
add_scalar_relu_out,
|
||||
inplace_add_scalar_relu,
|
||||
quantized_add_scalar_relu_out,
|
||||
add_scalar_filter},
|
||||
{"quantized::add_scalar_relu_out",
|
||||
inplace_add_scalar_inplace_relu,
|
||||
quantized_add_scalar_relu_out,
|
||||
add_scalar_filter},
|
||||
{"quantized::add_scalar",
|
||||
@ -703,9 +762,11 @@ graph(%a_quant, %normalized_shape, %weight, %bias, %eps, %cudnn_enabled, %output
|
||||
quantized_add_scalar,
|
||||
add_scalar_filter},
|
||||
{"quantized::add_scalar_out",
|
||||
add_scalar_out,
|
||||
inplace_add_scalar,
|
||||
quantized_add_scalar_out,
|
||||
add_scalar_filter},
|
||||
{"quantized::add", add, quantized_add, add_filter},
|
||||
{"quantized::add", inplace_add, quantized_add, add_filter},
|
||||
{"quantized::cat", cat, quantized_cat},
|
||||
{"quantized::batch_norm2d", batch_norm2d, quantized_batch_norm2d},
|
||||
{"quantized::batch_norm2d_relu",
|
||||
@ -714,14 +775,20 @@ graph(%a_quant, %normalized_shape, %weight, %bias, %eps, %cudnn_enabled, %output
|
||||
{"quantized::batch_norm2d_relu",
|
||||
batch_norm2d_inplace_relu,
|
||||
quantized_batch_norm2d_relu},
|
||||
{"quantized::mul", mul, quantized_mul},
|
||||
{"quantized::mul", inplace_mul, quantized_mul},
|
||||
{"quantized::mul_scalar_relu",
|
||||
mul_scalar_relu,
|
||||
quantized_mul_scalar_relu,
|
||||
mul_scalar_filter},
|
||||
{"quantized::mul_scalar_relu",
|
||||
mul_scalar_inplace_relu,
|
||||
quantized_mul_scalar_relu,
|
||||
mul_scalar_filter},
|
||||
{"quantized::mul_scalar_relu_out",
|
||||
mul_scalar_relu_out,
|
||||
inplace_mul_scalar_relu,
|
||||
quantized_mul_scalar_relu_out,
|
||||
mul_scalar_filter},
|
||||
{"quantized::mul_scalar_relu_out",
|
||||
inplace_mul_scalar_inplace_relu,
|
||||
quantized_mul_scalar_relu_out,
|
||||
mul_scalar_filter},
|
||||
{"quantized::mul_scalar",
|
||||
@ -729,13 +796,15 @@ graph(%a_quant, %normalized_shape, %weight, %bias, %eps, %cudnn_enabled, %output
|
||||
quantized_mul_scalar,
|
||||
mul_scalar_filter},
|
||||
{"quantized::mul_scalar",
|
||||
mul_scalar_out,
|
||||
inplace_mul_scalar,
|
||||
quantized_mul_scalar_out,
|
||||
mul_scalar_filter},
|
||||
{"quantized::mul_relu", mul_relu, quantized_mul_relu},
|
||||
{"quantized::mul_relu", mul_inplace_relu, quantized_mul_relu},
|
||||
{"quantized::mul_relu", inplace_mul_relu, quantized_mul_relu},
|
||||
{"quantized::mul_relu", inplace_mul_inplace_relu, quantized_mul_relu},
|
||||
{"quantized::mul", mul, quantized_mul},
|
||||
{"quantized::mul", inplace_mul, quantized_mul},
|
||||
{"quantized::hardswish", hardswish, quantized_hardswish},
|
||||
{"quantized::layer_norm", layer_norm, quantized_layer_norm},
|
||||
avg_pool1d,
|
||||
|
Reference in New Issue
Block a user