Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Add qint8 type (int8_t) (#19984)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19984

Add qint8 for QTensor, with underlying type of int8_t.

Reviewed By: jianyuh
Differential Revision: D15150715
fbshipit-source-id: 57580f599d46f9323af5ce462dbbc464b25e40d7
Committed by: Facebook Github Bot
Parent: 986c9eb537
Commit: 85fad0597c
@@ -39,6 +39,9 @@ static DLDataType getDLDataType(const Tensor& t) {
     case ScalarType::Bool:
       dtype.code = DLDataTypeCode::kDLUInt;
       break;
+    case ScalarType::QInt8:
+      throw std::logic_error("QInt8 is not supported by dlpack");
+      break;
     case ScalarType::QUInt8:
       throw std::logic_error("QUInt8 is not supported by dlpack");
       break;
@@ -211,6 +211,8 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
 #define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...)        \
   [&] {                                                \
     switch (TYPE) {                                    \
+      AT_PRIVATE_CASE_TYPE(                            \
+          at::ScalarType::QInt8, qint8, __VA_ARGS__)   \
       AT_PRIVATE_CASE_TYPE(                            \
           at::ScalarType::QUInt8, quint8, __VA_ARGS__) \
       AT_PRIVATE_CASE_TYPE(                            \
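AT_DISPATCH_QINT_TYPES expands to a switch over the quantized scalar types and binds scalar_t to the matching c10 type inside the lambda. A minimal, hypothetical sketch of the usage pattern (the fill_quantized helper is illustrative and not part of this commit):

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Hypothetical helper (not in the commit): fill a quantized tensor with a
// raw underlying value, dispatching over qint8 / quint8 / qint32.
void fill_quantized(at::Tensor& t, int64_t raw_value) {
  AT_DISPATCH_QINT_TYPES(t.scalar_type(), "fill_quantized", [&] {
    // Inside the lambda, scalar_t is c10::qint8, c10::quint8, or c10::qint32.
    using underlying_t = typename scalar_t::underlying;
    scalar_t* data = t.data<scalar_t>();
    for (int64_t i = 0; i < t.numel(); ++i) {
      data[i] = scalar_t(static_cast<underlying_t>(raw_value));
    }
  });
}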
@@ -188,6 +188,7 @@ extension_backends = ['MSNPU', 'XLA']

 # scalar_name, c_type, accreal, is_floating_type
 quantized_scalar_types = [
+    ('QInt8', 'qint8', 'QInt8AccrealNotDefined', 'QInt8IsFloatingTypeNotDefined'),
     ('QUInt8', 'quint8', 'QUInt8AccrealNotDefined', 'QUInt8IsFloatingTypeNotDefined'),
     ('QInt32', 'qint32', 'QInt32AccrealNotDefined', 'Qint32IsFloatingTypeNotDefined'),
 ]
@@ -8,9 +8,6 @@
 namespace at {
 namespace native {
 Tensor& _s_copy__quantized(Tensor& self, const Tensor& src, bool /* unused */) {
-  TORCH_CHECK(
-      self.scalar_type() == at::kQUInt8,
-      "Quantized copy only works with kQUInt8 as target Tensor");
   TORCH_CHECK(
       src.scalar_type() == at::kFloat,
       "Quantized copy only works with kFloat as source Tensor");
@@ -18,6 +18,7 @@ type_map = {
         'Bool',
     ],
     'quantized': [
+        'QInt8',
         'QUInt8',
         'QInt32',
     ]
@@ -152,11 +152,13 @@ Tensor dequantize_tensor(Tensor qtensor, Tensor rtensor, float scale, int32_t ze
   return rtensor;
 }
 #endif

+template CAFFE2_API qint8 quantize_val<qint8>(float scale, int32_t zero_point, float value);
 template CAFFE2_API quint8 quantize_val<quint8>(float scale, int32_t zero_point, float value);
 template CAFFE2_API qint32 quantize_val<qint32>(float scale, int32_t zero_point, float value);
+template CAFFE2_API Tensor quantize_tensor<qint8>(Tensor rtensor, Tensor qtensor, float scale, int32_t zero_point);
 template CAFFE2_API Tensor quantize_tensor<quint8>(Tensor rtensor, Tensor qtensor, float scale, int32_t zero_point);
 template CAFFE2_API Tensor quantize_tensor<qint32>(Tensor rtensor, Tensor qtensor, float scale, int32_t zero_point);
+template CAFFE2_API Tensor dequantize_tensor<qint8>(Tensor rtensor, Tensor qtensor, float scale, int32_t zero_point);
 template CAFFE2_API Tensor dequantize_tensor<quint8>(Tensor rtensor, Tensor qtensor, float scale, int32_t zero_point);
 template CAFFE2_API Tensor dequantize_tensor<qint32>(Tensor rtensor, Tensor qtensor, float scale, int32_t zero_point);
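These instantiations export the affine (scale / zero-point) quantization primitives for all three quantized types, now including the new qint8. For reference, a self-contained sketch of the arithmetic that a quantize/dequantize pair performs, written against plain int8_t rather than the actual library internals:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

// Affine quantization: q = clamp(round(r / scale) + zero_point, qmin, qmax).
int8_t quantize_affine_int8(float scale, int32_t zero_point, float value) {
  constexpr int32_t qmin = std::numeric_limits<int8_t>::min();  // -128
  constexpr int32_t qmax = std::numeric_limits<int8_t>::max();  //  127
  const int32_t q =
      static_cast<int32_t>(std::nearbyint(value / scale)) + zero_point;
  return static_cast<int8_t>(std::min(qmax, std::max(qmin, q)));
}

// Dequantization recovers an approximation: r ~= scale * (q - zero_point).
float dequantize_affine_int8(float scale, int32_t zero_point, int8_t q) {
  return scale * (static_cast<int32_t>(q) - zero_point);
}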
@@ -29,8 +29,9 @@ namespace c10 {
   _(std::complex<float>, ComplexFloat, z) /* 9 */ \
   _(std::complex<double>, ComplexDouble, z) /* 10 */ \
   _(bool, Bool, i) /* 11 */ \
-  _(c10::quint8, QUInt8, i) /* 12 */ \
-  _(c10::qint32, QInt32, i) /* 13 */
+  _(c10::qint8, QInt8, i) /* 12 */ \
+  _(c10::quint8, QUInt8, i) /* 13 */ \
+  _(c10::qint32, QInt32, i) /* 14 */

 // If you want to support ComplexHalf for real, replace occurrences
 // of this macro with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX. But
@@ -47,6 +48,7 @@ namespace c10 {
   _(std::complex<float>, ComplexFloat, z) \
   _(std::complex<double>, ComplexDouble, z) \
   _(bool, Bool, i) \
+  _(c10::qint8, QInt8, i) \
   _(c10::quint8, QUInt8, i) \
   _(c10::qint32, QInt32, i)

@@ -72,6 +74,7 @@ namespace c10 {
   _(at::Half, Half, d) \
   _(float, Float, d) \
   _(double, Double, d) \
+  _(c10::qint8, QInt8, i) \
   _(c10::quint8, QUInt8, i) \
   _(c10::qint32, QInt32, i)

@@ -104,6 +107,7 @@ namespace c10 {
   _(int64_t, Long, i) \
   _(float, Float, d) \
   _(double, Double, d) \
+  _(c10::qint8, QInt8, i) \
   _(c10::quint8, QUInt8, i) \
   _(c10::qint32, QInt32, i)

@@ -227,7 +231,7 @@ static inline bool isComplexType(ScalarType t) {

 static inline bool isQIntType(ScalarType t) {
   // Don't forget to extend this when adding new QInt types
-  return t == ScalarType::QUInt8 || t == ScalarType::QInt32;
+  return t == ScalarType::QInt8 || t == ScalarType::QUInt8 || t == ScalarType::QInt32;
 }

 static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
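The _(ctype, Name, tag) lists above are X-macros: each client supplies its own _ macro and gets one expansion per entry, so adding qint8 to the lists threads it through every generated switch and table. A simplified, self-contained illustration of the pattern (not the actual c10 list):

#include <cstdint>
#include <iostream>

// Simplified X-macro list in the style of the c10 scalar-type lists.
#define FORALL_QINT_TYPES(_) \
  _(int8_t, QInt8)           \
  _(uint8_t, QUInt8)         \
  _(int32_t, QInt32)

int main() {
  // Each entry expands once: print every name with its storage size.
#define PRINT_ENTRY(ctype, name) \
  std::cout << #name << ": " << sizeof(ctype) << " byte(s)" << std::endl;
  FORALL_QINT_TYPES(PRINT_ENTRY)
#undef PRINT_ENTRY
}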
c10/util/qint8.h (new file, 17 lines)
@@ -0,0 +1,17 @@
+#pragma once
+#include <cstdint>
+
+namespace c10 {
+
+/**
+ * This is the data type for quantized Tensors. Right now we only have
+ * qint8 which is for 8 bit Tensors, and qint32 for 32 bit int Tensors,
+ * we might have 4 bit, 2 bit or 1 bit data types in the future.
+ */
+struct alignas(1) qint8 {
+  using underlying = int8_t;
+  int8_t val_;
+  explicit qint8(int8_t val) : val_(val) {}
+};
+
+} // namespace c10
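The new wrapper buys type safety rather than behavior: the storage is a plain int8_t, but c10::qint8 is a distinct C++ type, so quantized data cannot be silently mixed with ordinary int8 data. A small sketch of what the definition guarantees:

#include <c10/util/qint8.h>
#include <type_traits>

int main() {
  // Same size and alignment as the underlying integer: no storage overhead.
  static_assert(sizeof(c10::qint8) == sizeof(int8_t), "one byte");
  static_assert(alignof(c10::qint8) == 1, "alignas(1)");
  // The underlying typedef lets generic code recover the raw integer type.
  static_assert(std::is_same<c10::qint8::underlying, int8_t>::value, "");

  c10::qint8 q(-5);     // construction is explicit: no implicit conversions
  int8_t raw = q.val_;  // the raw quantized value stays accessible
  return raw == -5 ? 0 : 1;
}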
@@ -82,8 +82,9 @@ CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(

 CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(27, float*)
 CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(28, at::Half*)
-CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(29, c10::quint8)
-CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(30, c10::qint32)
-CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(31, _CaffeHighestPreallocatedTypeId)
+CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(29, c10::qint8)
+CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(30, c10::quint8)
+CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(31, c10::qint32)
+CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(32, _CaffeHighestPreallocatedTypeId)

 } // namespace caffe2
@@ -23,6 +23,7 @@
 #include <c10/util/Exception.h>
 #include <c10/util/Half.h>
 #include <c10/util/IdWrapper.h>
+#include <c10/util/qint8.h>
 #include <c10/util/quint8.h>
 #include <c10/util/qint32.h>
 #include <c10/util/Type.h>
@@ -625,8 +626,9 @@ CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(

 CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(27, float*)
 CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(28, at::Half*)
-CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(29, c10::quint8)
-CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(30, c10::qint32)
-CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(31, _CaffeHighestPreallocatedTypeId)
+CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(29, c10::qint8)
+CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(30, c10::quint8)
+CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(31, c10::qint32)
+CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(32, _CaffeHighestPreallocatedTypeId)

 } // namespace caffe2
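The preallocated-known-type macros pin small, stable integer IDs to common types at compile time; inserting qint8 at slot 29 shifts quint8, qint32, and the sentinel up by one, which is why the DEFINE and DECLARE lists change in lockstep. A simplified, self-contained sketch of the fixed-ID idea (not the actual caffe2 macros):

#include <cstdint>

// Simplified sketch: a compile-time registry that pins a fixed numeric ID
// to each known type, in the spirit of CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE.
template <typename T>
struct PreallocatedTypeId;  // only specialized for known types

#define DEFINE_KNOWN_TYPE(id, T)          \
  template <>                             \
  struct PreallocatedTypeId<T> {          \
    static constexpr uint16_t value = id; \
  };

DEFINE_KNOWN_TYPE(29, signed char)    // stand-in for c10::qint8
DEFINE_KNOWN_TYPE(30, unsigned char)  // stand-in for c10::quint8
DEFINE_KNOWN_TYPE(31, int32_t)        // stand-in for c10::qint32

int main() {
  return PreallocatedTypeId<signed char>::value == 29 ? 0 : 1;
}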
@@ -18,7 +18,7 @@ class FunctionalAPITest(TestCase):
         Y = X.numpy().copy()
         Y[Y < 0] = 0
         qY = _quantize(Y, scale, zero_point)
-        qX = X.quantize_linear(scale=scale, zero_point=zero_point, dtype=torch.qint8)
+        qX = X.quantize_linear(scale=scale, zero_point=zero_point, dtype=torch.quint8)
         qY_hat = F.relu(qX)
         np.testing.assert_equal(qY, qY_hat.int_repr())

@@ -2796,6 +2796,9 @@ class _TestTorchMixin(object):
         r = torch.from_numpy(r).float()
         scale = 2
         zero_point = 2
         qr = r.quantize_linear(scale, zero_point, torch.qint8)
         rqr = qr.dequantize()
         self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / scale))
+        qr = r.quantize_linear(scale, zero_point, torch.quint8)
+        rqr = qr.dequantize()
+        self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / scale))
@@ -41,6 +41,8 @@ static std::pair<std::string, std::string> getDtypeNames(
       return std::make_pair("complex128", "");
     case at::ScalarType::Bool:
       return std::make_pair("bool", "");
+    case at::ScalarType::QInt8:
+      return std::make_pair("qint8", "");
     case at::ScalarType::QUInt8:
       return std::make_pair("quint8", "");
     case at::ScalarType::QInt32: