Unify the addQuantDequantNode api for inputs and outputs from quant nodes (#20677)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20677

With the recent changes to the IR, it is now possible to insert nodes after
param nodes in the graph. Thus we no longer need two separate methods for
inserting q-dq nodes on the inputs and on the outputs of quantizable nodes.

Differential Revision: D15406354

fbshipit-source-id: 1963762f434fd82877fa76a272e8520c342b6069
This commit is contained in:
Nishant Pandit
2019-05-20 15:24:36 -07:00
committed by Facebook Github Bot
parent cf548ba683
commit be33434d85

View File

@ -30,6 +30,15 @@ int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor"};
return quantnodeLookup.find(n);
}
// Walks two producer hops up from `dq`: first to the node that produces
// dq's first input, then to the node that produces that node's first input,
// and returns the latter. In the quant -> int_repr -> dequant pattern built
// in this file, that is the quant node feeding the given dequant node.
// Each hop is guarded by internal asserts (non-null node, at least one input).
Node* traverseToQuantNode(Node* dq) {
  // Hop 1: dequant -> producer of its first input.
  TORCH_INTERNAL_ASSERT(dq != nullptr);
  TORCH_INTERNAL_ASSERT(dq->inputs().size() != 0);
  Value* dq_first_input = dq->inputs()[0];
  Node* intrepr = dq_first_input->node();

  // Hop 2: that producer -> producer of its own first input.
  TORCH_INTERNAL_ASSERT(intrepr != nullptr);
  TORCH_INTERNAL_ASSERT(intrepr->inputs().size() != 0);
  Value* intrepr_first_input = intrepr->inputs()[0];
  return intrepr_first_input->node();
}
// Look for index of particular param in op schema
size_t getParamIndexinOpArgs(Node* n, const std::string& param_name) {
TORCH_INTERNAL_ASSERT(n != nullptr);
@ -114,51 +123,47 @@ Value* insertScalarType(Node* ins_node, at::ScalarType t) {
}
// Create Quant Node
// Builds an aten::quantize_linear node, asserts that creation succeeded, and
// names the node's output "<unique name of v>.quant". No inputs are attached
// here; the callers visible in this file position the node and add its inputs.
// NOTE(review): this text is a flattened diff — the form taking `Node* n`
// (n->owningGraph(), setScope(n->scope())) and the form taking `Graph* g`
// are interleaved below. Confirm against the real file which lines survive.
Node* createQuantNode(Value* v, Node* n) {
Node* quant = n->owningGraph()->create(
at::Symbol::fromQualString("aten::quantize_linear"));
Node* createQuantNode(Value* v, Graph* g) {
Node* quant = g->create(at::Symbol::fromQualString("aten::quantize_linear"));
TORCH_INTERNAL_ASSERT(quant != nullptr, "Failed to create quant node");
quant->output()->setUniqueName(v->uniqueName() + ".quant");
quant->setScope(n->scope());
return quant;
}
// Create Dequant node
// Builds an aten::dequantize_linear node, asserts that creation succeeded,
// and names the node's output "<unique name of v>.dequant". No inputs are
// attached here; the callers visible in this file position the node and add
// its inputs.
// NOTE(review): this text is a flattened diff — the form taking `Node* n`
// (n->owningGraph(), setScope(n->scope())) and the form taking `Graph* g`
// are interleaved below. Confirm against the real file which lines survive.
Node* createDeQuantNode(Value* v, Node* n) {
Node* dequant = n->owningGraph()->create(
at::Symbol::fromQualString("aten::dequantize_linear"));
Node* createDeQuantNode(Value* v, Graph* g) {
Node* dequant =
g->create(at::Symbol::fromQualString("aten::dequantize_linear"));
TORCH_INTERNAL_ASSERT(dequant != nullptr, "Failed to create dequant node");
dequant->output()->setUniqueName(v->uniqueName() + ".dequant");
dequant->setScope(n->scope());
return dequant;
}
// Create IntTensor Node
// Builds an aten::int_repr node, asserts that creation succeeded, and names
// the node's output "<unique name of v>.intrepr". No inputs are attached
// here; the callers visible in this file position the node and add its inputs.
// NOTE(review): this text is a flattened diff — the form taking `Node* n`
// (n->owningGraph(), setScope(n->scope())) and the form taking `Graph* g`
// are interleaved below. Confirm against the real file which lines survive.
Node* createIntReprNode(Value* v, Node* n) {
Node* intrepr =
n->owningGraph()->create(at::Symbol::fromQualString("aten::int_repr"));
Node* createIntReprNode(Value* v, Graph* g) {
Node* intrepr = g->create(at::Symbol::fromQualString("aten::int_repr"));
TORCH_INTERNAL_ASSERT(intrepr != nullptr, "Failed to create inttensor node");
intrepr->output()->setUniqueName(v->uniqueName() + ".intrepr");
intrepr->setScope(n->scope());
return intrepr;
}
// Insert Quant-Dequant node pattern for quantizable node outputs
void addQuantDeQuantNodes(
Node* addQuantDeQuantNodesFor(
Value* v,
Node* insert_point,
const std::tuple<std::string, float, int>& qparam,
at::ScalarType t = at::ScalarType::Undefined) {
TORCH_INTERNAL_ASSERT(v != nullptr);
Node* n = v->node();
Node* quant = createQuantNode(v, n);
Node* intrepr = createIntReprNode(v, n);
Node* dequant = createDeQuantNode(v, n);
WithCurrentScope scope_guard(
*insert_point->owningGraph(), insert_point->scope());
Node* quant = createQuantNode(v, insert_point->owningGraph());
Node* intrepr = createIntReprNode(v, insert_point->owningGraph());
Node* dequant = createDeQuantNode(v, insert_point->owningGraph());
// Add quant-intrepr-dequant nodes and replace for all uses of Value
quant->insertAfter(n);
quant->insertAfter(insert_point);
intrepr->insertAfter(quant);
dequant->insertAfter(intrepr);
v->replaceAllUsesWith(dequant->output());
// Attach inputs to quantization pattern nodes
quant->addInput(v);
@ -178,44 +183,7 @@ void addQuantDeQuantNodes(
quant->addInput(scalartype_v);
dequant->addInput(scalartype_v);
}
}
// Insert Quant-Dequant node pattern for specific input to node n
void addQuantDeQuantNodesForInput(
Value* v,
Node* n,
const std::tuple<std::string, float, int>& qparam,
at::ScalarType t = at::ScalarType::Undefined) {
TORCH_INTERNAL_ASSERT(v != nullptr);
TORCH_INTERNAL_ASSERT(n != nullptr);
Node* quant = createQuantNode(v, n);
Node* intrepr = createIntReprNode(v, n);
Node* dequant = createDeQuantNode(v, n);
// Insert the quant-intrepr-dequant node for the V->N
// pair which is identified as quantizable during
// graph iteration
dequant->insertBefore(n);
intrepr->insertBefore(dequant);
quant->insertBefore(intrepr);
n->replaceInputWith(v, dequant->output());
// Attach inputs to quantization pattern nodes
quant->addInput(v);
intrepr->addInput(quant->output());
dequant->addInput(intrepr->output());
// Insert qparam nodes
auto qparam_values = insertQuantParamNodes(quant, qparam);
for (Value* qparam_value : qparam_values) {
quant->addInput(qparam_value);
dequant->addInput(qparam_value);
}
if (t != at::ScalarType::Undefined) {
Value* scalartype_v = insertScalarType(quant, t);
TORCH_INTERNAL_ASSERT(scalartype_v != nullptr);
quant->addInput(scalartype_v);
dequant->addInput(scalartype_v);
}
return dequant;
}
template <typename... ArgT>
@ -373,7 +341,7 @@ void InsertQuantDequantNodes(
blocks_to_visit.push(graph->block());
// For storing quantizable values - node pairs that are external
// or intermediate inputs to quantizable nodes
std::vector<std::pair<Value*, Node*>> quantInputs;
std::vector<param_info_t> quantInputs;
// For storing quantizable values that are output of quantizable nodes
// Since same value can go to multiple nodes, we use set so that
// we insert quant-dequant node pairs for value only once
@ -440,7 +408,7 @@ void InsertQuantDequantNodes(
// N1 is not quantizable node but N4 and N7 are
// quantizable nodes. So we add the (V1, N4) and
// (V2, N7) as insertion points for quant-dequant nodes
quantInputs.emplace_back(v, n);
quantInputs.emplace_back(param_info_t{v, n, 0});
}
}
} // End Loop for nodes within block
@ -463,17 +431,27 @@ void InsertQuantDequantNodes(
}
// Insert the quant-dequant pair for values output from quantizable nodes
for (auto& ele : quantOutputs) {
if (qparam_value_dict.count(ele) != 0) {
addQuantDeQuantNodes(ele, qparam_value_dict[ele]);
for (auto& v_to_quant : quantOutputs) {
if (qparam_value_dict.count(v_to_quant) != 0) {
Node* dq = addQuantDeQuantNodesFor(
v_to_quant, v_to_quant->node(), qparam_value_dict[v_to_quant]);
TORCH_INTERNAL_ASSERT(dq != nullptr);
v_to_quant->replaceAllUsesWith(dq->output());
// Above step replaces v->quant with vdq->quant. We need to restore link.
// Below chain traverse up from dq to q node.
Node* q = traverseToQuantNode(dq);
TORCH_INTERNAL_ASSERT(q != nullptr);
q->replaceInputWith(dq->output(), v_to_quant);
}
}
// Insert the quant-dequant pair for values inputs to quantizable nodes
for (auto& ele : quantInputs) {
if (qparam_value_dict.count(ele.first) != 0) {
addQuantDeQuantNodesForInput(
ele.first, ele.second, qparam_value_dict[ele.first]);
for (auto& param_info : quantInputs) {
if (qparam_value_dict.count(param_info.v) != 0) {
Node* dq = addQuantDeQuantNodesFor(
param_info.v, param_info.v->node(), qparam_value_dict[param_info.v]);
TORCH_INTERNAL_ASSERT(dq != nullptr);
param_info.n->replaceInputWith(param_info.v, dq->output());
}
}
}
@ -500,7 +478,10 @@ void InsertQuantDequantNodesForParam(
const auto& itensor = param_slot.value();
at::Tensor tensor_var = itensor.toTensor().detach();
auto qparam = getQParamFunc(tensor_var);
addQuantDeQuantNodesForInput(param_info.v, param_info.n, qparam, t);
Node* dq = addQuantDeQuantNodesFor(
param_info.v, param_info.v->node()->next(), qparam, t);
TORCH_INTERNAL_ASSERT(dq != nullptr);
param_info.n->replaceInputWith(param_info.v, dq->output());
}
}