Add mandatory ScalarType nodes as input to the quant-dequant nodes. (#20468)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20468

The ScalarType node is now mandatory for both activations and parameters.
This change inserts a ScalarType node as an input to every quant-dequant node pair. For activations, the dtype previously defaulted to at::ScalarType::Undefined; that default is removed and at::ScalarType::QUInt8 is now passed explicitly.
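
For context, here is a minimal standalone sketch (not part of this diff) of quantizing and dequantizing a tensor with an explicit dtype, using today's public ATen API rather than the JIT-internal helpers touched below; the scale and zero_point values are arbitrary placeholders:

#include <torch/torch.h>
#include <iostream>

int main() {
  // A float tensor standing in for an activation.
  at::Tensor x = torch::rand({2, 3});
  // Quantize with an explicit dtype, mirroring the now-mandatory
  // ScalarType input on the graph-level quant node.
  at::Tensor q = at::quantize_per_tensor(
      x, /*scale=*/0.1, /*zero_point=*/128, at::ScalarType::QUInt8);
  // Dequantize back to float, as the paired dequant node does.
  at::Tensor dq = q.dequantize();
  std::cout << dq.sizes() << "\n";
  return 0;
}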

Differential Revision: D15331600

fbshipit-source-id: 5b51e0b42e694bf409026af4783a12da6d7e234b
Author: Nishant Pandit
Date: 2019-05-20 19:52:47 -07:00
Committed by: Facebook Github Bot
Parent: 371cf109a3
Commit: d73caca2a1
2 changed files with 15 additions and 12 deletions

Changed file 1 of 2:

@@ -184,7 +184,7 @@ void initJITBindings(PyObject* module) {
               method_name,
               param_name,
               getQParamFunc,
-              at::ScalarType::QUInt8);
+              at::ScalarType::QInt8);
         }
       })
       .def(

Changed file 2 of 2:

@@ -152,7 +152,7 @@ Node* addQuantDeQuantNodesFor(
     Value* v,
     Node* insert_point,
     const std::tuple<std::string, float, int>& qparam,
-    at::ScalarType t = at::ScalarType::Undefined) {
+    at::ScalarType t) {
   TORCH_INTERNAL_ASSERT(v != nullptr);
   WithCurrentScope scope_guard(
       *insert_point->owningGraph(), insert_point->scope());
@@ -175,14 +175,11 @@ Node* addQuantDeQuantNodesFor(
     quant->addInput(qparam_value);
     dequant->addInput(qparam_value);
   }
-  // optional argument required only for quantization
-  // of specific attributes eg: bias.
-  if (t != at::ScalarType::Undefined) {
-    Value* scalartype_v = insertScalarType(quant, t);
-    TORCH_INTERNAL_ASSERT(scalartype_v != nullptr);
-    quant->addInput(scalartype_v);
-    dequant->addInput(scalartype_v);
-  }
+  // Add ScalarType Node for q-dq
+  Value* scalartype_v = insertScalarType(quant, t);
+  TORCH_INTERNAL_ASSERT(scalartype_v != nullptr);
+  quant->addInput(scalartype_v);
+  dequant->addInput(scalartype_v);
   return dequant;
 }
@@ -434,7 +431,10 @@ void InsertQuantDequantNodes(
   for (auto& v_to_quant : quantOutputs) {
     if (qparam_value_dict.count(v_to_quant) != 0) {
       Node* dq = addQuantDeQuantNodesFor(
-          v_to_quant, v_to_quant->node(), qparam_value_dict[v_to_quant]);
+          v_to_quant,
+          v_to_quant->node(),
+          qparam_value_dict[v_to_quant],
+          at::ScalarType::QUInt8);
       TORCH_INTERNAL_ASSERT(dq != nullptr);
       v_to_quant->replaceAllUsesWith(dq->output());
       // Above step replaces v->quant with vdq->quant. We need to restore link.
@@ -449,7 +449,10 @@
   for (auto& param_info : quantInputs) {
     if (qparam_value_dict.count(param_info.v) != 0) {
       Node* dq = addQuantDeQuantNodesFor(
-          param_info.v, param_info.v->node(), qparam_value_dict[param_info.v]);
+          param_info.v,
+          param_info.v->node(),
+          qparam_value_dict[param_info.v],
+          at::ScalarType::QUInt8);
       TORCH_INTERNAL_ASSERT(dq != nullptr);
       param_info.n->replaceInputWith(param_info.v, dq->output());
     }
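
As a rough illustration of the insertScalarType helper referenced in the hunks above, here is a hypothetical sketch of how a ScalarType constant could be materialized in a JIT graph; the actual in-tree implementation may differ, and the include path has moved between PyTorch releases:

#include <torch/csrc/jit/ir/ir.h>  // older releases: torch/csrc/jit/ir.h

namespace torch {
namespace jit {

// Hypothetical sketch: create an integer prim::Constant holding the
// ScalarType value and insert it just before the quant node, so both
// the quant and dequant nodes can take it as an extra input.
Value* insertScalarTypeSketch(Node* quant, at::ScalarType t) {
  Graph* graph = quant->owningGraph();
  Node* c = graph->create(prim::Constant);
  c->i_(attr::value, static_cast<int64_t>(t));
  c->output()->setType(IntType::get());
  c->insertBefore(quant);
  return c->output();
}

} // namespace jit
} // namespace torch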