Unify the addQuantDequantNode api for inputs and outputs from quant nodes (#20677)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20677

With the new changes in IR it is possible to insert nodes after param nodes in the graph, so we no longer need two separate methods for inserting q-dq nodes at the inputs and at the outputs of quantizable nodes; a single addQuantDeQuantNodesFor helper covers both cases.

Differential Revision: D15406354

fbshipit-source-id: 1963762f434fd82877fa76a272e8520c342b6069
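As a reading aid (not part of the commit itself), here is a minimal sketch of how the two call sites in the diff below drive the single helper. It assumes the addQuantDeQuantNodesFor, traverseToQuantNode, and param_info_t definitions introduced in this diff; the wrapper names quantizeOutputSketch and quantizeInputSketch are illustrative only.

// Sketch only: condensed from the call sites in this diff, not a drop-in API.
// Output of a quantizable node: insert q-dq after the node producing v,
// redirect all uses of v to the dequant output, then restore the quant
// node's own input so it still reads the original value.
void quantizeOutputSketch(
    Value* v,
    const std::tuple<std::string, float, int>& qparam) {
  Node* dq = addQuantDeQuantNodesFor(v, v->node(), qparam);
  v->replaceAllUsesWith(dq->output());
  Node* q = traverseToQuantNode(dq);
  q->replaceInputWith(dq->output(), v);
}

// Input to a quantizable node: insert q-dq after v's producer, but rewire
// only the consuming node stored in param_info_t instead of every use of v.
void quantizeInputSketch(
    const param_info_t& info,
    const std::tuple<std::string, float, int>& qparam) {
  Node* dq = addQuantDeQuantNodesFor(info.v, info.v->node(), qparam);
  info.n->replaceInputWith(info.v, dq->output());
}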
committed by Facebook Github Bot
parent cf548ba683
commit be33434d85
@@ -30,6 +30,15 @@ int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor"};
   return quantnodeLookup.find(n);
 }
 
+Node* traverseToQuantNode(Node* dq) {
+  TORCH_INTERNAL_ASSERT(dq != nullptr);
+  TORCH_INTERNAL_ASSERT(dq->inputs().size() != 0);
+  Node* intrepr = dq->inputs()[0]->node();
+  TORCH_INTERNAL_ASSERT(intrepr != nullptr);
+  TORCH_INTERNAL_ASSERT(intrepr->inputs().size() != 0);
+  return intrepr->inputs()[0]->node();
+}
+
 // Look for index of particular param in op schema
 size_t getParamIndexinOpArgs(Node* n, const std::string& param_name) {
   TORCH_INTERNAL_ASSERT(n != nullptr);
@@ -114,51 +123,47 @@ Value* insertScalarType(Node* ins_node, at::ScalarType t) {
 }
 
 // Create Quant Node
-Node* createQuantNode(Value* v, Node* n) {
-  Node* quant = n->owningGraph()->create(
-      at::Symbol::fromQualString("aten::quantize_linear"));
+Node* createQuantNode(Value* v, Graph* g) {
+  Node* quant = g->create(at::Symbol::fromQualString("aten::quantize_linear"));
   TORCH_INTERNAL_ASSERT(quant != nullptr, "Failed to create quant node");
   quant->output()->setUniqueName(v->uniqueName() + ".quant");
-  quant->setScope(n->scope());
   return quant;
 }
 
 // Create Dequant node
-Node* createDeQuantNode(Value* v, Node* n) {
-  Node* dequant = n->owningGraph()->create(
-      at::Symbol::fromQualString("aten::dequantize_linear"));
+Node* createDeQuantNode(Value* v, Graph* g) {
+  Node* dequant =
+      g->create(at::Symbol::fromQualString("aten::dequantize_linear"));
   TORCH_INTERNAL_ASSERT(dequant != nullptr, "Failed to create dequant node");
   dequant->output()->setUniqueName(v->uniqueName() + ".dequant");
-  dequant->setScope(n->scope());
   return dequant;
 }
 
 // Create IntTensor Node
-Node* createIntReprNode(Value* v, Node* n) {
-  Node* intrepr =
-      n->owningGraph()->create(at::Symbol::fromQualString("aten::int_repr"));
+Node* createIntReprNode(Value* v, Graph* g) {
+  Node* intrepr = g->create(at::Symbol::fromQualString("aten::int_repr"));
   TORCH_INTERNAL_ASSERT(intrepr != nullptr, "Failed to create inttensor node");
   intrepr->output()->setUniqueName(v->uniqueName() + ".intrepr");
-  intrepr->setScope(n->scope());
   return intrepr;
 }
 
 // Insert Quant-Dequant node pattern for quantizable node outputs
-void addQuantDeQuantNodes(
+Node* addQuantDeQuantNodesFor(
     Value* v,
+    Node* insert_point,
     const std::tuple<std::string, float, int>& qparam,
     at::ScalarType t = at::ScalarType::Undefined) {
   TORCH_INTERNAL_ASSERT(v != nullptr);
-  Node* n = v->node();
-  Node* quant = createQuantNode(v, n);
-  Node* intrepr = createIntReprNode(v, n);
-  Node* dequant = createDeQuantNode(v, n);
+  WithCurrentScope scope_guard(
+      *insert_point->owningGraph(), insert_point->scope());
+  Node* quant = createQuantNode(v, insert_point->owningGraph());
+  Node* intrepr = createIntReprNode(v, insert_point->owningGraph());
+  Node* dequant = createDeQuantNode(v, insert_point->owningGraph());
 
   // Add quant-intrepr-dequant nodes and replace for all uses of Value
-  quant->insertAfter(n);
+  quant->insertAfter(insert_point);
   intrepr->insertAfter(quant);
   dequant->insertAfter(intrepr);
-  v->replaceAllUsesWith(dequant->output());
 
   // Attach inputs to quantization pattern nodes
   quant->addInput(v);
@@ -178,44 +183,7 @@ void addQuantDeQuantNodes(
     quant->addInput(scalartype_v);
     dequant->addInput(scalartype_v);
   }
-}
-
-// Insert Quant-Dequant node pattern for specific input to node n
-void addQuantDeQuantNodesForInput(
-    Value* v,
-    Node* n,
-    const std::tuple<std::string, float, int>& qparam,
-    at::ScalarType t = at::ScalarType::Undefined) {
-  TORCH_INTERNAL_ASSERT(v != nullptr);
-  TORCH_INTERNAL_ASSERT(n != nullptr);
-  Node* quant = createQuantNode(v, n);
-  Node* intrepr = createIntReprNode(v, n);
-  Node* dequant = createDeQuantNode(v, n);
-
-  // Insert the quant-intrepr-dequant node for the V->N
-  // pair which is identified as quantizable during
-  // graph iteration
-  dequant->insertBefore(n);
-  intrepr->insertBefore(dequant);
-  quant->insertBefore(intrepr);
-  n->replaceInputWith(v, dequant->output());
-
-  // Attach inputs to quantization pattern nodes
-  quant->addInput(v);
-  intrepr->addInput(quant->output());
-  dequant->addInput(intrepr->output());
-  // Insert qparam nodes
-  auto qparam_values = insertQuantParamNodes(quant, qparam);
-  for (Value* qparam_value : qparam_values) {
-    quant->addInput(qparam_value);
-    dequant->addInput(qparam_value);
-  }
-  if (t != at::ScalarType::Undefined) {
-    Value* scalartype_v = insertScalarType(quant, t);
-    TORCH_INTERNAL_ASSERT(scalartype_v != nullptr);
-    quant->addInput(scalartype_v);
-    dequant->addInput(scalartype_v);
-  }
+  return dequant;
 }
 
 template <typename... ArgT>
@@ -373,7 +341,7 @@ void InsertQuantDequantNodes(
   blocks_to_visit.push(graph->block());
   // For storing quantizable values - node pairs that are external
   // or intermediate inputs to quantizable nodes
-  std::vector<std::pair<Value*, Node*>> quantInputs;
+  std::vector<param_info_t> quantInputs;
   // For storing quantizable values that are output of quantizable nodes
   // Since same value can go to multiple nodes, we use set so that
   // we insert quant-dequant node pairs for value only once
@@ -440,7 +408,7 @@ void InsertQuantDequantNodes(
         // N1 is not quantizable node but N4 and N7 are
         // quantizable nodes. So we add the (V1, N4) and
         // (V2, N7) as insertion points for quant-dequant nodes
-        quantInputs.emplace_back(v, n);
+        quantInputs.emplace_back(param_info_t{v, n, 0});
       }
     }
   } // End Loop for nodes within block
@@ -463,17 +431,27 @@ void InsertQuantDequantNodes(
   }
 
   // Insert the quant-dequant pair for values output from quantizable nodes
-  for (auto& ele : quantOutputs) {
-    if (qparam_value_dict.count(ele) != 0) {
-      addQuantDeQuantNodes(ele, qparam_value_dict[ele]);
+  for (auto& v_to_quant : quantOutputs) {
+    if (qparam_value_dict.count(v_to_quant) != 0) {
+      Node* dq = addQuantDeQuantNodesFor(
+          v_to_quant, v_to_quant->node(), qparam_value_dict[v_to_quant]);
+      TORCH_INTERNAL_ASSERT(dq != nullptr);
+      v_to_quant->replaceAllUsesWith(dq->output());
+      // Above step replaces v->quant with vdq->quant. We need to restore link.
+      // Below chain traverse up from dq to q node.
+      Node* q = traverseToQuantNode(dq);
+      TORCH_INTERNAL_ASSERT(q != nullptr);
+      q->replaceInputWith(dq->output(), v_to_quant);
     }
   }
 
   // Insert the quant-dequant pair for values inputs to quantizable nodes
-  for (auto& ele : quantInputs) {
-    if (qparam_value_dict.count(ele.first) != 0) {
-      addQuantDeQuantNodesForInput(
-          ele.first, ele.second, qparam_value_dict[ele.first]);
+  for (auto& param_info : quantInputs) {
+    if (qparam_value_dict.count(param_info.v) != 0) {
+      Node* dq = addQuantDeQuantNodesFor(
+          param_info.v, param_info.v->node(), qparam_value_dict[param_info.v]);
+      TORCH_INTERNAL_ASSERT(dq != nullptr);
+      param_info.n->replaceInputWith(param_info.v, dq->output());
     }
   }
 }
@@ -500,7 +478,10 @@ void InsertQuantDequantNodesForParam(
     const auto& itensor = param_slot.value();
     at::Tensor tensor_var = itensor.toTensor().detach();
     auto qparam = getQParamFunc(tensor_var);
-    addQuantDeQuantNodesForInput(param_info.v, param_info.n, qparam, t);
+    Node* dq = addQuantDeQuantNodesFor(
+        param_info.v, param_info.v->node()->next(), qparam, t);
+    TORCH_INTERNAL_ASSERT(dq != nullptr);
+    param_info.n->replaceInputWith(param_info.v, dq->output());
   }
 }
 