Add quant-dequant nodes for bias. (#20045)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20045

This pass adds quant-dequant nodes for bias. It requires the quant-dequant
pass for activations and weights to have run first, because their scales are
needed to compute the qparams for the bias.
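
For reference, a minimal sketch of the intended call order, based on the test
added in this PR (the module `m` and the two callbacks are placeholders, and
the qparam values are illustrative):

    import torch

    # m is a torch.jit.ScriptModule whose forward already carries
    # quantize_linear / int_repr / dequantize_linear nodes for its activations.
    def getQParamFuncW(value):
        # (qscheme, scale, zero_point) for the weight
        return 'per_tensor_quant', 0.5, 1

    def getQParamFunc(input_scale, weight_scale):
        # qparams for the bias are derived from the activation and weight scales
        return 'per_tensor_quant', 1 / input_scale / weight_scale, 0

    torch._C._jit_pass_constant_propagation(m.graph)
    # weight first (quantized as QInt8 inside the pass) ...
    torch._C._jit_pass_insert_quantdequant_for_weight_bias(
        m._c, "forward", "weight", getQParamFuncW)
    # ... then bias (QInt32), which needs the activation and weight scales
    # to already be present in the graph
    torch._C._jit_pass_insert_quantdequant_for_weight_bias(
        m._c, "forward", "bias", getQParamFunc)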

Differential Revision: D15179141

fbshipit-source-id: 3aab9fceefcadc3fa42a4e802d9b1e18addad78a
Author: Nishant Pandit
Date: 2019-05-21 21:54:09 -07:00
Committed by: Facebook Github Bot
Parent: c2d0e7316f
Commit: a501e7d5be

3 changed files with 133 additions and 6 deletions


@@ -1456,16 +1456,61 @@ graph(%x : Tensor,
        # to insert quant-dequant nodes for quantizable tensors. The type analysis
        # happens as part of this jit pass
        torch._C._jit_pass_constant_propagation(scriptModule.graph)
        torch._C._jit_pass_insert_quantdequant_for_param(scriptModule._c,
                                                         "forward",
                                                         "weight",
                                                         getQParamFunc)
        torch._C._jit_pass_insert_quantdequant_for_weight_bias(scriptModule._c,
                                                               "forward",
                                                               "weight",
                                                               getQParamFunc)

        # We expect to see quant-dequant node before conv node for weight.
        FileCheck().check("quantize_linear").check_next("int_repr") \
                   .check_next("dequantize_linear") \
                   .check("conv2d").run(str(scriptModule.graph))
    def test_insert_quantdequant_for_bias(self):
        # Inserting quant-dequant nodes for bias requires scale info to be
        # present for activation and weight, so the q-dq pass for those
        # inputs is done first.

        class testModule(torch.jit.ScriptModule):
            def __init__(self):
                super(testModule, self).__init__()
                self.conv1 = nn.Conv2d(1, 1, 1, 1).float()

            @torch.jit.script_method
            def forward(self, x):
                x = x.quantize_linear(1.0, 0, torch.uint8)
                x = x.int_repr()
                x = x.dequantize_linear(1.0, 0, torch.uint8)
                x = self.conv1(x)
                return x

        def getQParamFuncW(value):
            return 'per_tensor_quant', 0.5, 1

        def getQParamFunc(input_scale, weight_scale):
            scale = 1 / input_scale / weight_scale
            zero_point = 0
            return 'per_tensor_quant', scale, zero_point

        scriptModule = testModule()
        torch._C._jit_pass_constant_propagation(scriptModule.graph)
        torch._C._jit_pass_insert_quantdequant_for_weight_bias(scriptModule._c,
                                                               "forward",
                                                               "weight",
                                                               getQParamFuncW)
        torch._C._jit_pass_insert_quantdequant_for_weight_bias(scriptModule._c,
                                                               "forward",
                                                               "bias",
                                                               getQParamFunc)

        # We expect to see 3 pairs of quant-dequant nodes.
        FileCheck().check("quantize_linear").check_next("int_repr") \
                   .check_next("dequantize_linear").check("quantize_linear") \
                   .check_next("int_repr").check_next("dequantize_linear") \
                   .check("quantize_linear").check_next("int_repr") \
                   .check_next("dequantize_linear").check("conv2d") \
                   .run(str(scriptModule.graph))
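
A quick worked example of the bias qparam computation, using the values from
this test: the activation scale is the 1.0 hard-coded in dequantize_linear and
the weight scale is the 0.5 returned by getQParamFuncW, so getQParamFunc yields

    input_scale, weight_scale = 1.0, 0.5
    scale = 1 / input_scale / weight_scale   # == 2.0
    zero_point = 0
    # -> ('per_tensor_quant', 2.0, 0), used to quantize the bias to QInt32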
    def test_pattern_based_rewrite(self):
        # mul(mul(mul(mul(x,y),z),x),y) --> mul(mul(mulmul(x,y,z), x), y) -->
        # --> mulmul(mulmul(x,y,z), x, y)


@@ -170,11 +170,13 @@ void initJITBindings(PyObject* module) {
            return InsertQuantDequantNodes(g, qparam_dict);
          })
      .def(
          "_jit_pass_insert_quantdequant_for_param",
          "_jit_pass_insert_quantdequant_for_weight_bias",
          [](std::shared_ptr<script::Module>& moduleObj,
             const std::string& method_name,
             const std::string& param_name,
             py::function pyGetQParamFunc) {
            // For different static params we pass a different getQParamFunc
            // through the same interface exposed by the quantizer.
            if (param_name == std::string("weight")) {
              auto getQParamFunc =
                  py::cast<std::function<std::tuple<std::string, float, int>(
@@ -185,6 +187,18 @@ void initJITBindings(PyObject* module) {
                param_name,
                getQParamFunc,
                at::ScalarType::QInt8);
          } else if (param_name == std::string("bias")) {
            auto getQParamFunc =
                py::cast<std::function<std::tuple<std::string, float, int>(
                    float, float)>>(pyGetQParamFunc);
            InsertQuantDequantNodesForParam(
                moduleObj,
                method_name,
                param_name,
                getQParamFunc,
                at::ScalarType::QInt32);
          } else {
            TORCH_CHECK(false, "Invalid Param Name");
          }
        })
    .def(
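
From the Python side this binding dispatches on param_name: "weight" expects a
callback returning (qscheme, scale, zero_point) and quantizes to QInt8, while
"bias" expects a callback taking the activation and weight scales and quantizes
to QInt32. A minimal sketch of the error path for any other name (m and
getQParamFunc are placeholders; the exact Python exception type depends on how
the bindings translate c10 errors):

    try:
        torch._C._jit_pass_insert_quantdequant_for_weight_bias(
            m._c, "forward", "gamma", getQParamFunc)  # "gamma" is a made-up name
    except RuntimeError as e:
        print(e)  # "Invalid Param Name"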


@@ -30,6 +30,15 @@ int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor"};
  return quantnodeLookup.find(n);
}

Value* getScaleValue(Node* n) {
  if (n->kind().toQualString() != std::string("aten::dequantize_linear")) {
    return nullptr;
  }
  TORCH_CHECK(n->inputs().size() == 4);
  // Fetch the scale from the dequant node
  return n->inputs()[1];
}

Node* traverseToQuantNode(Node* dq) {
  TORCH_INTERNAL_ASSERT(dq != nullptr);
  TORCH_INTERNAL_ASSERT(dq->inputs().size() != 0);
@@ -488,8 +497,59 @@ void InsertQuantDequantNodesForParam(
}
}
void InsertQuantDequantNodesForParam(
    script::Method& method,
    const std::string& param_name,
    const std::function<std::tuple<std::string, float, int>(float, float)>&
        getQParamFunc,
    at::ScalarType t) {
  TORCH_CHECK(getQParamFunc != nullptr);
  auto params_to_insert_qdq = getQuantizableParamsofName(method, param_name);

  for (param_info_t& param_info : params_to_insert_qdq) {
    // This getQParamFunc requires the activation and weight scales because,
    // for quantized ops that do a matmul with weight and bias (Wx + b), the
    // qparams for the bias are computed from the input activation and weight
    // scales. If the weight attribute is not present we skip inserting the
    // q-dq nodes.
    Node* n = param_info.n;

    // Check if this node has a weight attribute as input
    size_t param_index = getParamIndexinOpArgs(n, std::string("weight"));
    if (param_index >= n->inputs().size()) {
      // No attribute named weight
      continue;
    }

    std::vector<size_t> node_inputs_idx{0, param_index};
    std::array<float, 2> scale_factors = {0, 0};
    bool skip_node = false;
    for (size_t idx = 0; idx < node_inputs_idx.size(); idx++) {
      size_t input_index = node_inputs_idx[idx];
      Value* input_value = n->inputs()[input_index];
      Node* n_input_value = input_value->node();
      Value* scale_value = getScaleValue(n_input_value);
      if (!scale_value) {
        // Dequant node pattern for this input is missing
        skip_node = true;
        break;
      }
      c10::IValue scale_ivalue = toIValue(scale_value).value();
      float input_scale = static_cast<float>(scale_ivalue.toDouble());
      TORCH_CHECK(input_scale != 0.0);
      scale_factors[idx] = input_scale;
    }
    if (skip_node) {
      continue;
    }

    auto bias_qparam = getQParamFunc(scale_factors[0], scale_factors[1]);
    Node* dq = addQuantDeQuantNodesFor(
        param_info.v, param_info.v->node()->next(), bias_qparam, t);
    TORCH_INTERNAL_ASSERT(dq != nullptr);
    param_info.n->replaceInputWith(param_info.v, dq->output());
  }
}
// Exposing the template api helps reuse the same interface for different
// qparamfunc for different qschemes.
// qparamfunc for different qschemes and params.
template <typename Fn>
void InsertQuantDequantNodesForParam(
    std::shared_ptr<script::Module>& moduleObj,
@@ -510,5 +570,13 @@ template TORCH_API void InsertQuantDequantNodesForParam(
    getQParamFunc,
    at::ScalarType t);

template TORCH_API void InsertQuantDequantNodesForParam(
    std::shared_ptr<script::Module>& moduleObj,
    const std::string& method_name,
    const std::string& param_name,
    const std::function<std::tuple<std::string, float, int>(float, float)>&
        getQParamFunc,
    at::ScalarType t);
} // namespace jit
} // namespace torch