mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Add quant-dequant nodes for bias. (#20045)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20045
This pass adds quant-dequant nodes for bias. It requires the quant-dequant passes for activations and weights to have run first, since their scales are needed to compute the qparams for the bias.
Differential Revision: D15179141
fbshipit-source-id: 3aab9fceefcadc3fa42a4e802d9b1e18addad78a
committed by Facebook Github Bot
parent c2d0e7316f
commit a501e7d5be
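As context for the diff below: the bias qparams are produced by a caller-supplied callback that receives the activation scale and the weight scale recovered from the dequantize_linear nodes inserted by the earlier passes. A minimal sketch of such a callback, mirroring the getQParamFunc defined in the new test (the specific scale formula is the test's choice, not something fixed by the pass):

def get_bias_qparams(input_scale, weight_scale):
    # Combine activation and weight scales into the bias scale, exactly as the
    # test's getQParamFunc does; other qschemes may compute this differently.
    scale = 1 / input_scale / weight_scale
    zero_point = 0
    return 'per_tensor_quant', scale, zero_point

# With the values used in the test (activation scale 1.0 from quantize_linear,
# weight scale 0.5 from getQParamFuncW) the bias scale comes out to 2.0:
print(get_bias_qparams(1.0, 0.5))  # ('per_tensor_quant', 2.0, 0)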
@@ -1456,16 +1456,61 @@ graph(%x : Tensor,
        # to insert quant-dequant nodes for quantizable tensors. The type analysis
        # happens as part of this jit pass
        torch._C._jit_pass_constant_propagation(scriptModule.graph)
        torch._C._jit_pass_insert_quantdequant_for_param(scriptModule._c,
                                                         "forward",
                                                         "weight",
                                                         getQParamFunc)
        torch._C._jit_pass_insert_quantdequant_for_weight_bias(scriptModule._c,
                                                               "forward",
                                                               "weight",
                                                               getQParamFunc)

        # We expect to see quant-dequant node before conv node for weight.
        FileCheck().check("quantize_linear").check_next("int_repr") \
                   .check_next("dequantize_linear") \
                   .check("conv2d").run(str(scriptModule.graph))

    def test_insert_quantdequant_for_bias(self):
        # Inserting quant-dequant nodes for bias requires scale info present for
        # activation and weight so q-dq pass done first for these inputs.

        class testModule(torch.jit.ScriptModule):
            def __init__(self):
                super(testModule, self).__init__()
                self.conv1 = nn.Conv2d(1, 1, 1, 1).float()

            @torch.jit.script_method
            def forward(self, x):
                x = x.quantize_linear(1.0, 0, torch.uint8)
                x = x.int_repr()
                x = x.dequantize_linear(1.0, 0, torch.uint8)
                x = self.conv1(x)
                return x

        def getQParamFuncW(value):
            return 'per_tensor_quant', 0.5, 1

        def getQParamFunc(input_scale, weight_scale):
            scale = 1 / input_scale / weight_scale
            zero_point = 0
            return 'per_tensor_quant', scale, zero_point

        scriptModule = testModule()

        torch._C._jit_pass_constant_propagation(scriptModule.graph)
        torch._C._jit_pass_insert_quantdequant_for_weight_bias(scriptModule._c,
                                                               "forward",
                                                               "weight",
                                                               getQParamFuncW)
        torch._C._jit_pass_insert_quantdequant_for_weight_bias(scriptModule._c,
                                                               "forward",
                                                               "bias",
                                                               getQParamFunc)
        # We expect to see 3 pairs of quant-dequant nodes.

        FileCheck().check("quantize_linear").check_next("int_repr") \
                   .check_next("dequantize_linear").check("quantize_linear") \
                   .check_next("int_repr").check_next("dequantize_linear") \
                   .check("quantize_linear").check_next("int_repr") \
                   .check_next("dequantize_linear").check("conv2d") \
                   .run(str(scriptModule.graph))

    def test_pattern_based_rewrite(self):
        # mul(mul(mul(mul(x,y),z),x),y) --> mul(mul(mulmul(x,y,z), x), y) -->
        # --> mulmul(mulmul(x,y,z), x, y)
@@ -170,11 +170,13 @@ void initJITBindings(PyObject* module) {
            return InsertQuantDequantNodes(g, qparam_dict);
          })
      .def(
          "_jit_pass_insert_quantdequant_for_param",
          "_jit_pass_insert_quantdequant_for_weight_bias",
          [](std::shared_ptr<script::Module>& moduleObj,
             const std::string& method_name,
             const std::string& param_name,
             py::function pyGetQParamFunc) {
            // For different static params we pass different getQParamFunc via
            // same interface exposed by the quantizer.
            if (param_name == std::string("weight")) {
              auto getQParamFunc =
                  py::cast<std::function<std::tuple<std::string, float, int>(
@@ -185,6 +187,18 @@ void initJITBindings(PyObject* module) {
                  param_name,
                  getQParamFunc,
                  at::ScalarType::QInt8);
            } else if (param_name == std::string("bias")) {
              auto getQParamFunc =
                  py::cast<std::function<std::tuple<std::string, float, int>(
                      float, float)>>(pyGetQParamFunc);
              InsertQuantDequantNodesForParam(
                  moduleObj,
                  method_name,
                  param_name,
                  getQParamFunc,
                  at::ScalarType::QInt32);
            } else {
              TORCH_CHECK(false, "Invalid Param Name");
            }
          })
      .def(
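The binding above keeps a single entry point but casts pyGetQParamFunc to a different std::function signature depending on param_name, and uses QInt8 for weights versus QInt32 for bias; any other param_name hits the TORCH_CHECK. A small sketch of the two Python callback shapes it therefore expects, taken from the callbacks the updated test passes in (the weight callback's exact argument list is truncated in this hunk, so its single-argument form here is an assumption based on the test's getQParamFuncW):

# Callback for param_name == "weight": returns (qscheme, scale, zero_point).
def weight_qparam_func(value):
    return 'per_tensor_quant', 0.5, 1

# Callback for param_name == "bias": receives two floats, the activation scale
# and the weight scale, matching the std::function<...(float, float)> cast.
def bias_qparam_func(input_scale, weight_scale):
    return 'per_tensor_quant', 1 / input_scale / weight_scale, 0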
@@ -30,6 +30,15 @@ int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor"};
  return quantnodeLookup.find(n);
}

Value* getScaleValue(Node* n) {
  if (n->kind().toQualString() != std::string("aten::dequantize_linear")) {
    return nullptr;
  }
  TORCH_CHECK(n->inputs().size() == 4);
  // Fetch scale from the dequant node
  return n->inputs()[1];
}

Node* traverseToQuantNode(Node* dq) {
  TORCH_INTERNAL_ASSERT(dq != nullptr);
  TORCH_INTERNAL_ASSERT(dq->inputs().size() != 0);
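getScaleValue above only recognizes aten::dequantize_linear nodes, insists on exactly four inputs, and returns input 1 as the scale. That matches the call shape used in the test, x.dequantize_linear(scale, zero_point, dtype), where x itself is input 0. A toy illustration with plain tuples standing in for JIT nodes (the tuple layout is purely illustrative):

# Hypothetical stand-in for a JIT node: (kind, inputs).
dequant_node = ("aten::dequantize_linear", ["x", 1.0, 0, "torch.uint8"])

def get_scale_value(node):
    kind, inputs = node
    if kind != "aten::dequantize_linear":
        return None          # mirrors the nullptr return for other node kinds
    assert len(inputs) == 4  # (input, scale, zero_point, dtype)
    return inputs[1]         # the scale sits at input index 1

print(get_scale_value(dequant_node))  # 1.0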
@@ -488,8 +497,59 @@ void InsertQuantDequantNodesForParam(
  }
}

void InsertQuantDequantNodesForParam(
    script::Method& method,
    const std::string& param_name,
    const std::function<std::tuple<std::string, float, int>(float, float)>&
        getQParamFunc,
    at::ScalarType t) {
  TORCH_CHECK(getQParamFunc != nullptr);
  auto params_to_insert_qdq = getQuantizableParamsofName(method, param_name);

  for (param_info_t& param_info : params_to_insert_qdq) {
    // This getQParamFunc requires scale for weight and activation because for
    // quantized ops that involve matmul with weight and bias(WX+b), input scale
    // for bias is computed from input activation and weight. if weight attr
    // not present we skip inserting q-dq node.
    Node* n = param_info.n;
    // Check if this node has weight attr as input
    size_t param_index = getParamIndexinOpArgs(n, std::string("weight"));
    if (param_index >= n->inputs().size()) {
      // No attribute by name weight
      continue;
    }

    std::vector<size_t> node_inputs_idx{0, param_index};
    std::array<float, 2> scale_factors = {0, 0};
    bool skip_node = false;
    for (size_t idx = 0; idx < node_inputs_idx.size(); idx++) {
      size_t input_index = node_inputs_idx[idx];
      Value* input_value = n->inputs()[input_index];
      Node* n_input_value = input_value->node();
      Value* scale_value = getScaleValue(n_input_value);
      if (!scale_value) {
        // Dequant node pattern for input is missing
        skip_node = true;
        break;
      }
      c10::IValue scale_ivalue = toIValue(scale_value).value();
      float input_scale = static_cast<float>(scale_ivalue.toDouble());
      TORCH_CHECK(input_scale != 0.0);
      scale_factors[idx] = input_scale;
    }
    if (skip_node) {
      continue;
    }
    auto bias_qparam = getQParamFunc(scale_factors[0], scale_factors[1]);
    Node* dq = addQuantDeQuantNodesFor(
        param_info.v, param_info.v->node()->next(), bias_qparam, t);
    TORCH_INTERNAL_ASSERT(dq != nullptr);
    param_info.n->replaceInputWith(param_info.v, dq->output());
  }
}

// Exposing the template api helps reuse the same interface for different
// qparamfunc for different qschemes.
// qparamfunc for different qschemes and params.
template <typename Fn>
void InsertQuantDequantNodesForParam(
    std::shared_ptr<script::Module>& moduleObj,
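The loop above reads the activation scale from node input 0 and the weight scale from the weight argument's position, skips the node when either input is not fed by a dequantize_linear (no scale to read), and otherwise hands both scales to getQParamFunc. A minimal sketch of that decision logic over plain Python floats, with the actual graph traversal omitted:

def bias_qparams_for_node(activation_scale, weight_scale, get_qparam_func):
    # Skip path: the dequant pattern is missing for one of the inputs, so the
    # scales are unavailable and no quant-dequant nodes are inserted for bias.
    if activation_scale is None or weight_scale is None:
        return None
    assert activation_scale != 0.0 and weight_scale != 0.0
    return get_qparam_func(activation_scale, weight_scale)

# The weight input lacks a preceding dequantize_linear, so the node is skipped:
print(bias_qparams_for_node(
    1.0, None, lambda s_in, s_w: ('per_tensor_quant', 1 / s_in / s_w, 0)))  # None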
@@ -510,5 +570,13 @@ template TORCH_API void InsertQuantDequantNodesForParam(
    getQParamFunc,
    at::ScalarType t);

template TORCH_API void InsertQuantDequantNodesForParam(
    std::shared_ptr<script::Module>& moduleObj,
    const std::string& method_name,
    const std::string& param_name,
    const std::function<std::tuple<std::string, float, int>(float, float)>&
        getQParamFunc,
    at::ScalarType t);

} // namespace jit
} // namespace torch