Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Add dequantize_linear for JIT pass (#20107)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20107
att

Reviewed By: nishantpdce

Differential Revision: D15202187

fbshipit-source-id: 7d6274a67fcca695c0425587f35046fecbc2ccdc
committed by Facebook Github Bot
parent cc02a1af61
commit cca923c481
@@ -1,6 +1,7 @@
 #pragma once

 #include <ATen/core/Tensor.h>
+#include <c10/macros/Macros.h>
 #include <c10/util/Half.h>
 #include <c10/util/Exception.h>
 #include <ATen/core/DeprecatedTypeProperties.h>
@@ -11,6 +12,13 @@
     return __VA_ARGS__(); \
   }

+#define AT_QINT_PRIVATE_CASE_TYPE(enum_type, type, underlying_type, ...) \
+  case enum_type: { \
+    using scalar_t C10_UNUSED = type; \
+    using underlying_t C10_UNUSED = underlying_type; \
+    return __VA_ARGS__(); \
+  }
+
 namespace detail {

 template <at::ScalarType N>
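To see what the new macro buys, a single case expands roughly as follows (an illustrative sketch, not part of the diff; body stands in for the caller's lambda):

    // Approximate expansion of
    // AT_QINT_PRIVATE_CASE_TYPE(at::ScalarType::QInt8, qint8, int8_t, body)
    case at::ScalarType::QInt8: {
      using scalar_t C10_UNUSED = qint8;       // quantized wrapper type
      using underlying_t C10_UNUSED = int8_t;  // raw integral storage type
      return body();
    }

Compared with AT_PRIVATE_CASE_TYPE, the extra parameter exposes the underlying integral type to the dispatched lambda.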
@@ -211,14 +219,14 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
 #define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...) \
   [&] { \
     switch (TYPE) { \
-      AT_PRIVATE_CASE_TYPE( \
-          at::ScalarType::QInt8, qint8, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE( \
-          at::ScalarType::QUInt8, quint8, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE( \
-          at::ScalarType::QInt32, qint32, __VA_ARGS__) \
+      AT_QINT_PRIVATE_CASE_TYPE( \
+          at::ScalarType::QInt8, qint8, int8_t, __VA_ARGS__) \
+      AT_QINT_PRIVATE_CASE_TYPE( \
+          at::ScalarType::QUInt8, quint8, uint8_t, __VA_ARGS__) \
+      AT_QINT_PRIVATE_CASE_TYPE( \
+          at::ScalarType::QInt32, qint32, int, __VA_ARGS__) \
       default: \
         AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
     } \
   }()
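A typical call site for the updated dispatch macro would look like this (a minimal sketch with assumed names; self is some quantized tensor and "my_op" is a placeholder):

    AT_DISPATCH_QINT_TYPES(self.scalar_type(), "my_op", [&]() {
      // scalar_t is the quantized type (e.g. qint8); underlying_t is the
      // matching integral type (e.g. int8_t) made visible by the new macro.
      underlying_t* data = self.data<underlying_t>();
      // ... element-wise work on the raw integer representation ...
    });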
@@ -583,6 +583,7 @@ class CAFFE2_API Tensor {
   Tensor to_mkldnn() const;
   Tensor quantize_linear(double scale, int64_t zero_point, ScalarType dtype) const;
   Tensor dequantize() const;
+  Tensor dequantize_linear(double scale, int64_t zero_point, ScalarType dtype) const;
   Scalar q_scale() const;
   Scalar q_zero_point() const;
   Tensor int_repr() const;
@@ -804,6 +804,9 @@ inline Tensor Tensor::quantize_linear(double scale, int64_t zero_point, ScalarType dtype) const {
 inline Tensor Tensor::dequantize() const {
   return dispatch_type().dequantize(*this);
 }
+inline Tensor Tensor::dequantize_linear(double scale, int64_t zero_point, ScalarType dtype) const {
+  return dispatch_type().dequantize_linear(*this, scale, zero_point, dtype);
+}
 inline Scalar Tensor::q_scale() const {
   return dispatch_type().q_scale(*this);
 }
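The inline method only forwards to the type-dispatched virtual, so a caller can write, for example (hypothetical values):

    // Reinterpret an int8 tensor t as affine-quantized data with
    // scale 0.5 and zero point 10, producing a float tensor.
    at::Tensor f = t.dequantize_linear(/*scale=*/0.5, /*zero_point=*/10, at::kFloat);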
@@ -393,6 +393,7 @@ struct CAFFE2_API Type {
   virtual Tensor to_mkldnn(const Tensor & self) const = 0;
   virtual Tensor quantize_linear(const Tensor & self, double scale, int64_t zero_point, ScalarType dtype) const = 0;
   virtual Tensor dequantize(const Tensor & self) const = 0;
+  virtual Tensor dequantize_linear(const Tensor & self, double scale, int64_t zero_point, ScalarType dtype) const = 0;
   virtual Scalar q_scale(const Tensor & self) const = 0;
   virtual Scalar q_zero_point(const Tensor & self) const = 0;
   virtual Tensor int_repr(const Tensor & self) const = 0;
@@ -2551,6 +2551,11 @@
   dispatch:
     QuantizedCPU: dequantize_quant

+- func: dequantize_linear(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor
+  variants: function, method
+  dispatch:
+    CPU: dequantize_linear_cpu
+
 - func: q_scale(Tensor self) -> Scalar
   variants: function, method
   dispatch:
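Because the schema declares variants: function, method, the generated API should be callable both ways (a sketch under assumed names, with t a suitable int tensor):

    at::Tensor a = at::dequantize_linear(t, scale, zero_point, at::kFloat); // function variant
    at::Tensor b = t.dequantize_linear(scale, zero_point, at::kFloat);      // method variant

On CPU, both routes land in dequantize_linear_cpu via the dispatch entry above.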
@@ -16,6 +16,21 @@ Tensor dequantize_quant(const Tensor& self) {
   return get_qtensorimpl(self)->quantizer()->dequantize(self);
 }

+Tensor dequantize_linear_cpu(const Tensor& self, double scale, int64_t zero_point, ScalarType dtype) {
+  AT_CHECK(isQIntType(toQIntType(self.scalar_type())),
+           "Scalar type for quantized Tensor must have same underlying type as input.");
+  AT_CHECK(dtype == ScalarType::Float, "ScalarType for target Tensor must be float.");
+  Tensor f = at::empty(self.sizes(), self.options().dtype(dtype));
+  AT_DISPATCH_QINT_TYPES(
+      toQIntType(self.scalar_type()), "dequantize_linear_cpu", [&]() {
+        underlying_t* qdata = self.data<underlying_t>();
+        auto* fdata = f.data<float>();
+        for (int i = 0; i < self.numel(); ++i) {
+          fdata[i] = (static_cast<float>(qdata[i]) - zero_point) * scale;
+        }});
+  return f;
+}
+
 Scalar q_scale_quant(const Tensor& self) {
   auto quantizer = get_qtensorimpl(self)->quantizer();
   AT_ASSERT(quantizer->qscheme() == kPerTensorAffine);
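The inner loop applies the affine dequantization formula x = (q - zero_point) * scale elementwise. For example, with scale = 3 and zero_point = 2 (the values used in the test below), a stored integer q = 7 dequantizes to (7 - 2) * 3 = 15.0, and q = -10 to (-10 - 2) * 3 = -36.0.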
@@ -234,6 +234,36 @@ static inline bool isQIntType(ScalarType t) {
   return t == ScalarType::QInt8 || t == ScalarType::QUInt8 || t == ScalarType::QInt32;
 }

+static inline ScalarType toQIntType(ScalarType t) {
+  switch (t) {
+    case ScalarType::Byte:
+      return ScalarType::QUInt8;
+    case ScalarType::Char:
+      return ScalarType::QInt8;
+    case ScalarType::Int:
+      return ScalarType::QInt32;
+    default:
+      return t;
+  }
+}
+
+static inline ScalarType toUnderlying(ScalarType t) {
+  switch (t) {
+    case ScalarType::QUInt8:
+      return ScalarType::Byte;
+    case ScalarType::QInt8:
+      return ScalarType::Char;
+    case ScalarType::QInt32:
+      return ScalarType::Int;
+    default:
+      return t;
+  }
+}
+
+static inline bool isUnderlying(ScalarType type, ScalarType qtype) {
+  return type == toUnderlying(qtype);
+}
+
 static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
   // This is generated according to NumPy's promote_types
   constexpr auto u1 = ScalarType::Byte;
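These helpers define a one-to-one mapping between the quantized dtypes and their integral storage dtypes, and act as the identity on everything else. A quick sketch of how they compose (illustrative, not part of the diff):

    auto q = toQIntType(ScalarType::Char);  // -> ScalarType::QInt8
    auto u = toUnderlying(q);               // -> back to ScalarType::Char
    bool ok = isUnderlying(ScalarType::Char, ScalarType::QInt8);  // true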
@@ -209,6 +209,7 @@ view of a storage and defines numeric operations on it.
    .. automethod:: cumsum
    .. automethod:: data_ptr
    .. automethod:: dequantize
+   .. automethod:: dequantize_linear
    .. automethod:: det
    .. automethod:: dense_dim
    .. automethod:: detach
@@ -2804,6 +2804,12 @@ class _TestTorchMixin(object):
         rqr = qr.dequantize()
         self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / scale))

+    def test_qtensor_dequantize_linear(self):
+        t = torch.arange(-10, 10, dtype=torch.int8)
+        scale = 3
+        zero_point = 2
+        qt = torch.dequantize_linear(t, scale, zero_point, torch.float)
+

     @unittest.skipIf(torch.cuda.device_count() < 2, 'fewer than 2 GPUs detected')
     def test_device_guard(self):
@@ -3031,6 +3031,15 @@ det() -> Tensor
 See :func:`torch.det`
 """)

+add_docstr_all('dequantize_linear',
+               r"""
+dequantize_linear(int_tensor, scale, zero_point) -> Tensor
+
+Dequantizes an int Tensor that represents the underlying quantized data,
+using an affine quantization scheme with the given scale and zero_point.
+Returns a float Tensor.
+""")
+
 add_docstr_all('where',
                r"""
 where(condition, y) -> Tensor