Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Add mandatory ScalarType nodes as input to the quant-dequant nodes. (#20468)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20468

The ScalarType node is now mandatory for both activations and parameters, so this change inserts a ScalarType node for all quant-dequant nodes. For activations, the previous default of at::ScalarType::Undefined is removed and the at::ScalarType::QUInt8 dtype is passed explicitly.

Differential Revision: D15331600
fbshipit-source-id: 5b51e0b42e694bf409026af4783a12da6d7e234b
Commit d73caca2a1 (parent 371cf109a3), committed by Facebook Github Bot.
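
As a rough eager-mode analogue of what a quant-dequant pair with an explicit dtype computes (a hedged illustration, not part of this commit; it uses the current at::quantize_per_tensor API rather than the JIT nodes the pass inserts):

    // Sketch only: quantize an fp32 activation to QUInt8 with explicit scale,
    // zero_point, and dtype, then dequantize back, mirroring the (qparam,
    // ScalarType) inputs the pass now attaches to every quant-dequant pair.
    #include <torch/torch.h>
    #include <iostream>

    int main() {
      at::Tensor x = torch::rand({2, 3});          // fp32 activation
      double scale = 0.05;                         // example qparams, made up
      int64_t zero_point = 128;
      at::Tensor q = at::quantize_per_tensor(x, scale, zero_point, at::kQUInt8);
      at::Tensor dq = q.dequantize();              // back to fp32
      std::cout << (x - dq).abs().max() << "\n";   // quantization error
      return 0;
    }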
@@ -184,7 +184,7 @@ void initJITBindings(PyObject* module) {
                 method_name,
                 param_name,
                 getQParamFunc,
-                at::ScalarType::QUInt8);
+                at::ScalarType::QInt8);
           }
         })
       .def(
@@ -152,7 +152,7 @@ Node* addQuantDeQuantNodesFor(
     Value* v,
     Node* insert_point,
     const std::tuple<std::string, float, int>& qparam,
-    at::ScalarType t = at::ScalarType::Undefined) {
+    at::ScalarType t) {
   TORCH_INTERNAL_ASSERT(v != nullptr);
   WithCurrentScope scope_guard(
       *insert_point->owningGraph(), insert_point->scope());
@@ -175,14 +175,11 @@ Node* addQuantDeQuantNodesFor(
     quant->addInput(qparam_value);
     dequant->addInput(qparam_value);
   }
-  // optional argument required only for quantization
-  // of specific attributes eg: bias.
-  if (t != at::ScalarType::Undefined) {
-    Value* scalartype_v = insertScalarType(quant, t);
-    TORCH_INTERNAL_ASSERT(scalartype_v != nullptr);
-    quant->addInput(scalartype_v);
-    dequant->addInput(scalartype_v);
-  }
+  // Add ScalarType Node for q-dq
+  Value* scalartype_v = insertScalarType(quant, t);
+  TORCH_INTERNAL_ASSERT(scalartype_v != nullptr);
+  quant->addInput(scalartype_v);
+  dequant->addInput(scalartype_v);
   return dequant;
 }

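The helper insertScalarType referenced above is not part of this hunk. A minimal sketch of what such a helper could look like, assuming the dtype is carried as an integer constant in the IR (illustrative only, not the actual implementation):

    #include <torch/csrc/jit/ir.h>  // header path of this era; it has since moved

    using namespace torch::jit;

    // Hypothetical stand-in for insertScalarType: materialize the ScalarType as
    // a constant Value so it can be wired into the quant and dequant nodes.
    Value* insertScalarTypeSketch(Node* n, at::ScalarType t) {
      WithInsertPoint guard(n);  // insert the constant next to its consumer
      return n->owningGraph()->insertConstant(static_cast<int64_t>(t));
    }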
@@ -434,7 +431,10 @@ void InsertQuantDequantNodes(
   for (auto& v_to_quant : quantOutputs) {
     if (qparam_value_dict.count(v_to_quant) != 0) {
       Node* dq = addQuantDeQuantNodesFor(
-          v_to_quant, v_to_quant->node(), qparam_value_dict[v_to_quant]);
+          v_to_quant,
+          v_to_quant->node(),
+          qparam_value_dict[v_to_quant],
+          at::ScalarType::QUInt8);
       TORCH_INTERNAL_ASSERT(dq != nullptr);
       v_to_quant->replaceAllUsesWith(dq->output());
       // Above step replaces v->quant with vdq->quant. We need to restore link.
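The "restore link" comment in the hunk above refers to a side effect of Value::replaceAllUsesWith: the quant node itself is a use of v_to_quant, so it gets redirected to the dequant output along with every other use and must be pointed back. A hedged sketch of that pattern (the actual fix-up code sits outside this hunk):

    // Illustrative only; assumes the dequant node's first input is the quant output.
    Node* dq = addQuantDeQuantNodesFor(
        v_to_quant, v_to_quant->node(),
        qparam_value_dict[v_to_quant], at::ScalarType::QUInt8);
    v_to_quant->replaceAllUsesWith(dq->output());   // rewires *all* uses, incl. quant
    Node* quant = dq->input(0)->node();
    quant->replaceInputWith(dq->output(), v_to_quant);  // restore the v -> quant edge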
@@ -449,7 +449,10 @@ void InsertQuantDequantNodes(
   for (auto& param_info : quantInputs) {
     if (qparam_value_dict.count(param_info.v) != 0) {
       Node* dq = addQuantDeQuantNodesFor(
-          param_info.v, param_info.v->node(), qparam_value_dict[param_info.v]);
+          param_info.v,
+          param_info.v->node(),
+          qparam_value_dict[param_info.v],
+          at::ScalarType::QUInt8);
       TORCH_INTERNAL_ASSERT(dq != nullptr);
       param_info.n->replaceInputWith(param_info.v, dq->output());
     }