Add qint8 type (int8_t) (#19984)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19984

Add qint8 for QTensor, with an underlying type of int8_t

Reviewed By: jianyuh

Differential Revision: D15150715

fbshipit-source-id: 57580f599d46f9323af5ce462dbbc464b25e40d7
Author: Jerry Zhang
Date: 2019-05-17 20:29:33 -07:00
Committed by: Facebook Github Bot
Parent: 986c9eb537
Commit: 85fad0597c
13 changed files with 50 additions and 15 deletions
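For background: qint8 follows the same affine quantization scheme as the existing quint8 and qint32 types, mapping a real value to a signed 8-bit integer through a scale and a zero point. Below is a minimal standalone sketch of that mapping, assuming the standard formula; quantize_to_qint8 is a hypothetical helper for illustration, not code from this PR:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Affine quantization: q = zero_point + round(x / scale), clamped to
// the signed 8-bit range [-128, 127]. (Assumed standard scheme.)
int8_t quantize_to_qint8(float x, float scale, int32_t zero_point) {
  int32_t q = zero_point + static_cast<int32_t>(std::nearbyint(x / scale));
  q = std::min(127, std::max(-128, q));
  return static_cast<int8_t>(q);
}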


@@ -39,6 +39,9 @@ static DLDataType getDLDataType(const Tensor& t) {
     case ScalarType::Bool:
       dtype.code = DLDataTypeCode::kDLUInt;
       break;
+    case ScalarType::QInt8:
+      throw std::logic_error("QInt8 is not supported by dlpack");
+      break;
     case ScalarType::QUInt8:
       throw std::logic_error("QUInt8 is not supported by dlpack");
       break;


@@ -211,6 +211,8 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
 #define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...)        \
   [&] {                                                \
     switch (TYPE) {                                    \
+      AT_PRIVATE_CASE_TYPE(                            \
+          at::ScalarType::QInt8, qint8, __VA_ARGS__)   \
       AT_PRIVATE_CASE_TYPE(                            \
           at::ScalarType::QUInt8, quint8, __VA_ARGS__) \
       AT_PRIVATE_CASE_TYPE(                            \
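With this case added, AT_DISPATCH_QINT_TYPES can select a code path for qint8 alongside quint8 and qint32. A sketch of how such a dispatch is typically written; the kernel name and body here are hypothetical:

// Inside the lambda, the macro binds scalar_t to c10::qint8,
// c10::quint8, or c10::qint32 depending on the runtime dtype.
AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), "my_qkernel", [&]() {
  using underlying_t = scalar_t::underlying;  // int8_t when scalar_t is c10::qint8
  scalar_t* data = qtensor.data<scalar_t>();  // assumed accessor of the era
  // ... element-wise work on data[i].val_ ...
});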


@@ -188,6 +188,7 @@ extension_backends = ['MSNPU', 'XLA']
 # scalar_name, c_type, accreal, is_floating_type
 quantized_scalar_types = [
+    ('QInt8', 'qint8', 'QInt8AccrealNotDefined', 'QInt8IsFloatingTypeNotDefined'),
     ('QUInt8', 'quint8', 'QUInt8AccrealNotDefined', 'QUInt8IsFloatingTypeNotDefined'),
     ('QInt32', 'qint32', 'QInt32AccrealNotDefined', 'Qint32IsFloatingTypeNotDefined'),
 ]


@@ -8,9 +8,6 @@
 namespace at {
 namespace native {
 Tensor& _s_copy__quantized(Tensor& self, const Tensor& src, bool /* unused */) {
-  TORCH_CHECK(
-      self.scalar_type() == at::kQUInt8,
-      "Quantized copy only works with kQUInt8 as target Tensor");
   TORCH_CHECK(
       src.scalar_type() == at::kFloat,
       "Quantized copy only works with kFloat as source Tensor");


@@ -18,6 +18,7 @@ type_map = {
         'Bool',
     ],
     'quantized': [
+        'QInt8',
         'QUInt8',
         'QInt32',
     ]


@@ -152,11 +152,13 @@ Tensor dequantize_tensor(Tensor qtensor, Tensor rtensor, float scale, int32_t zero_point) {
   return rtensor;
 }
 #endif
+template CAFFE2_API qint8 quantize_val<qint8>(float scale, int32_t zero_point, float value);
 template CAFFE2_API quint8 quantize_val<quint8>(float scale, int32_t zero_point, float value);
 template CAFFE2_API qint32 quantize_val<qint32>(float scale, int32_t zero_point, float value);
+template CAFFE2_API Tensor quantize_tensor<qint8>(Tensor rtensor, Tensor qtensor, float scale, int32_t zero_point);
 template CAFFE2_API Tensor quantize_tensor<quint8>(Tensor rtensor, Tensor qtensor, float scale, int32_t zero_point);
 template CAFFE2_API Tensor quantize_tensor<qint32>(Tensor rtensor, Tensor qtensor, float scale, int32_t zero_point);
+template CAFFE2_API Tensor dequantize_tensor<qint8>(Tensor rtensor, Tensor qtensor, float scale, int32_t zero_point);
 template CAFFE2_API Tensor dequantize_tensor<quint8>(Tensor rtensor, Tensor qtensor, float scale, int32_t zero_point);
 template CAFFE2_API Tensor dequantize_tensor<qint32>(Tensor rtensor, Tensor qtensor, float scale, int32_t zero_point);
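These explicit instantiations give qint8 the same quantize/dequantize entry points the other quantized types already have. For reference, the inverse mapping that dequantization performs is the usual affine formula; this standalone sketch assumes the standard scheme and is not the PR's exact implementation:

#include <c10/util/qint8.h>

// Dequantization: x = scale * (q - zero_point). (Assumed standard scheme.)
float dequantize_qint8(c10::qint8 q, float scale, int32_t zero_point) {
  return scale * (static_cast<int32_t>(q.val_) - zero_point);
}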


@@ -29,8 +29,9 @@ namespace c10 {
   _(std::complex<float>, ComplexFloat, z) /* 9 */ \
   _(std::complex<double>, ComplexDouble, z) /* 10 */ \
   _(bool, Bool, i) /* 11 */ \
-  _(c10::quint8, QUInt8, i) /* 12 */ \
-  _(c10::qint32, QInt32, i) /* 13 */
+  _(c10::qint8, QInt8, i) /* 12 */ \
+  _(c10::quint8, QUInt8, i) /* 13 */ \
+  _(c10::qint32, QInt32, i) /* 14 */

 // If you want to support ComplexHalf for real, replace occurrences
 // of this macro with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX. But

@@ -47,6 +48,7 @@ namespace c10 {
   _(std::complex<float>, ComplexFloat, z) \
   _(std::complex<double>, ComplexDouble, z) \
   _(bool, Bool, i) \
+  _(c10::qint8, QInt8, i) \
   _(c10::quint8, QUInt8, i) \
   _(c10::qint32, QInt32, i)

@@ -72,6 +74,7 @@ namespace c10 {
   _(at::Half, Half, d) \
   _(float, Float, d) \
   _(double, Double, d) \
+  _(c10::qint8, QInt8, i) \
   _(c10::quint8, QUInt8, i) \
   _(c10::qint32, QInt32, i)

@@ -104,6 +107,7 @@ namespace c10 {
   _(int64_t, Long, i) \
   _(float, Float, d) \
   _(double, Double, d) \
+  _(c10::qint8, QInt8, i) \
   _(c10::quint8, QUInt8, i) \
   _(c10::qint32, QInt32, i)

@@ -227,7 +231,7 @@ static inline bool isComplexType(ScalarType t) {
 static inline bool isQIntType(ScalarType t) {
   // Don't forget to extend this when adding new QInt types
-  return t == ScalarType::QUInt8 || t == ScalarType::QInt32;
+  return t == ScalarType::QInt8 || t == ScalarType::QUInt8 || t == ScalarType::QInt32;
 }

 static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {

c10/util/qint8.h (new file)

@@ -0,0 +1,17 @@
+#pragma once
+#include <cstdint>
+
+namespace c10 {
+
+/**
+ * This is the data type for quantized Tensors. Right now we only have
+ * qint8 which is for 8 bit Tensors, and qint32 for 32 bit int Tensors,
+ * we might have 4 bit, 2 bit or 1 bit data types in the future.
+ */
+struct alignas(1) qint8 {
+  using underlying = int8_t;
+  int8_t val_;
+  explicit qint8(int8_t val) : val_(val) {}
+};
+
+} // namespace c10
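The struct is a thin, strongly typed wrapper: alignas(1) keeps it layout-compatible with a raw int8_t buffer, and the explicit constructor prevents silent conversion from plain integers. A small usage sketch, illustrative only:

#include <c10/util/qint8.h>
#include <type_traits>

void example() {
  c10::qint8 q(static_cast<int8_t>(-12));  // must be constructed explicitly
  int8_t raw = q.val_;                     // the raw stored value
  static_assert(std::is_same<c10::qint8::underlying, int8_t>::value,
                "underlying type is int8_t");
  (void)raw;
}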


@@ -82,8 +82,9 @@ CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(
 CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(27, float*)
 CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(28, at::Half*)
-CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(29, c10::quint8)
-CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(30, c10::qint32)
-CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(31, _CaffeHighestPreallocatedTypeId)
+CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(29, c10::qint8)
+CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(30, c10::quint8)
+CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(31, c10::qint32)
+CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(32, _CaffeHighestPreallocatedTypeId)

 } // namespace caffe2


@@ -23,6 +23,7 @@
 #include <c10/util/Exception.h>
 #include <c10/util/Half.h>
 #include <c10/util/IdWrapper.h>
+#include <c10/util/qint8.h>
 #include <c10/util/quint8.h>
 #include <c10/util/qint32.h>
 #include <c10/util/Type.h>

@@ -625,8 +626,9 @@ CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(
 CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(27, float*)
 CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(28, at::Half*)
-CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(29, c10::quint8)
-CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(30, c10::qint32)
-CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(31, _CaffeHighestPreallocatedTypeId)
+CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(29, c10::qint8)
+CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(30, c10::quint8)
+CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(31, c10::qint32)
+CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(32, _CaffeHighestPreallocatedTypeId)

 } // namespace caffe2


@@ -18,7 +18,7 @@ class FunctionalAPITest(TestCase):
         Y = X.numpy().copy()
         Y[Y < 0] = 0
         qY = _quantize(Y, scale, zero_point)
-        qX = X.quantize_linear(scale=scale, zero_point=zero_point, dtype=torch.qint8)
+        qX = X.quantize_linear(scale=scale, zero_point=zero_point, dtype=torch.quint8)
         qY_hat = F.relu(qX)
         np.testing.assert_equal(qY, qY_hat.int_repr())


@@ -2796,6 +2796,9 @@ class _TestTorchMixin(object):
         r = torch.from_numpy(r).float()
         scale = 2
         zero_point = 2
+        qr = r.quantize_linear(scale, zero_point, torch.qint8)
+        rqr = qr.dequantize()
+        self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / scale))
         qr = r.quantize_linear(scale, zero_point, torch.quint8)
         rqr = qr.dequantize()
         self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / scale))


@@ -41,6 +41,8 @@ static std::pair<std::string, std::string> getDtypeNames(
       return std::make_pair("complex128", "");
     case at::ScalarType::Bool:
       return std::make_pair("bool", "");
+    case at::ScalarType::QInt8:
+      return std::make_pair("qint8", "");
     case at::ScalarType::QUInt8:
       return std::make_pair("quint8", "");
     case at::ScalarType::QInt32: