Files
pytorch/torch/csrc/utils/tensor_qschemes.cpp
Dmytro Dzhulgakov ebc2365fd3 Serialization for per channel qtensor (#26339)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26339

Serializes per-channel quantized tensors in both torch.serialization and JIT. Since we haven't bound Quantizer properly yet, I chose to save a tuple representing the quantizer settings. To avoid recursive tensor serialization calls, I'm using a tuple instead of a tensor to store the scales and zero points.

driazati - please check the serialization logic. Is there a good test that checks that JIT serialization and Python serialization are equivalent? (I haven't tested it yet)

Test Plan: Imported from OSS

Differential Revision: D17443222

Pulled By: dzhulgakov

fbshipit-source-id: a34758de1ffd2ec1cdc5355f5baf95284a4ccf4b
2019-09-23 13:28:11 -07:00

45 lines
1.2 KiB
C++

#include <torch/csrc/utils/tensor_qschemes.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/QScheme.h>
#include <c10/core/QScheme.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/object_ptr.h>
namespace torch {
namespace utils {
// Cache of the Python QScheme objects, indexed by the integer value of
// at::QScheme. Filled in by initializeQSchemes(). Because the array has static
// storage duration it is zero-initialized, so any slot that has not been
// filled yet reads as nullptr (getTHPQScheme() relies on this).
static PyObject* thp_qscheme_array[at::COMPILE_TIME_NUM_QSCHEMES];
void initializeQSchemes() {
auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
if (!torch_module) {
throw python_error();
}
for (int i = 0; i < at::COMPILE_TIME_NUM_QSCHEMES; ++i) {
auto qscheme = static_cast<at::QScheme>(i);
PyObject* qscheme_obj = THPQScheme_New(qscheme, toString(qscheme));
thp_qscheme_array[static_cast<int>(qscheme)] = qscheme_obj;
Py_INCREF(qscheme_obj);
if (PyModule_AddObject(
torch_module, toString(qscheme).c_str(), qscheme_obj) != 0) {
throw python_error();
}
}
}
// Returns the cached Python object for `qscheme`.
//
// The pointer comes straight from thp_qscheme_array; its reference count is
// not incremented here.
//
// Throws std::invalid_argument if `qscheme` is outside the range covered by
// the cache, or if no object has been registered for it (for example because
// initializeQSchemes() has not filled that slot).
PyObject* getTHPQScheme(at::QScheme qscheme) {
  const auto index = static_cast<int>(qscheme);
  // Reject values that would index outside the cache instead of reading past
  // the end of the array.
  if (index < 0 || index >= at::COMPILE_TIME_NUM_QSCHEMES) {
    throw std::invalid_argument("unsupported QScheme");
  }

  auto qscheme_ = thp_qscheme_array[index];
  if (!qscheme_) {
    throw std::invalid_argument("unsupported QScheme");
  }
  return qscheme_;
}
} // namespace utils
} // namespace torch