Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 13:44:15 +08:00
Add dequantize_linear for JIT pass (#20107)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20107

att

Reviewed By: nishantpdce

Differential Revision: D15202187

fbshipit-source-id: 7d6274a67fcca695c0425587f35046fecbc2ccdc
Committed by Facebook Github Bot
parent cc02a1af61
commit cca923c481
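For orientation: the op this diff adds performs plain affine dequantization, f = (q - zero_point) * scale, elementwise. A minimal usage sketch, assuming a build that includes this patch (the call mirrors the test added below; `torch.dequantize_linear` is not part of current releases):

    import torch

    # Raw int8 values standing in for quantized data, as in the new test.
    t = torch.arange(-10, 10, dtype=torch.int8)
    scale = 3
    zero_point = 2

    # The new op: f[i] = (float(q[i]) - zero_point) * scale
    f = torch.dequantize_linear(t, scale, zero_point, torch.float)

    # Reference computation with ordinary tensor ops gives the same result.
    expected = (t.to(torch.float) - zero_point) * scale
    assert torch.allclose(f, expected)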
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <ATen/core/Tensor.h>
+#include <c10/macros/Macros.h>
 #include <c10/util/Half.h>
 #include <c10/util/Exception.h>
 #include <ATen/core/DeprecatedTypeProperties.h>
@@ -11,6 +12,13 @@
     return __VA_ARGS__(); \
   }
 
+#define AT_QINT_PRIVATE_CASE_TYPE(enum_type, type, underlying_type, ...) \
+  case enum_type: { \
+    using scalar_t C10_UNUSED = type; \
+    using underlying_t C10_UNUSED = underlying_type; \
+    return __VA_ARGS__(); \
+  }
+
 namespace detail {
 
 template <at::ScalarType N>
@@ -211,14 +219,14 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
 #define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...) \
   [&] { \
     switch (TYPE) { \
-      AT_PRIVATE_CASE_TYPE( \
-          at::ScalarType::QInt8, qint8, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE( \
-          at::ScalarType::QUInt8, quint8, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE( \
-          at::ScalarType::QInt32, qint32, __VA_ARGS__) \
+      AT_QINT_PRIVATE_CASE_TYPE( \
+          at::ScalarType::QInt8, qint8, int8_t, __VA_ARGS__) \
+      AT_QINT_PRIVATE_CASE_TYPE( \
+          at::ScalarType::QUInt8, quint8, uint8_t, __VA_ARGS__) \
+      AT_QINT_PRIVATE_CASE_TYPE( \
+          at::ScalarType::QInt32, qint32, int, __VA_ARGS__) \
       default: \
         AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
     } \
   }()
 
@@ -583,6 +583,7 @@ class CAFFE2_API Tensor {
   Tensor to_mkldnn() const;
   Tensor quantize_linear(double scale, int64_t zero_point, ScalarType dtype) const;
   Tensor dequantize() const;
+  Tensor dequantize_linear(double scale, int64_t zero_point, ScalarType dtype) const;
   Scalar q_scale() const;
   Scalar q_zero_point() const;
   Tensor int_repr() const;
@@ -804,6 +804,9 @@ inline Tensor Tensor::quantize_linear(double scale, int64_t zero_point, ScalarTy
 inline Tensor Tensor::dequantize() const {
   return dispatch_type().dequantize(*this);
 }
+inline Tensor Tensor::dequantize_linear(double scale, int64_t zero_point, ScalarType dtype) const {
+  return dispatch_type().dequantize_linear(*this, scale, zero_point, dtype);
+}
 inline Scalar Tensor::q_scale() const {
   return dispatch_type().q_scale(*this);
 }
@@ -393,6 +393,7 @@ struct CAFFE2_API Type {
   virtual Tensor to_mkldnn(const Tensor & self) const = 0;
   virtual Tensor quantize_linear(const Tensor & self, double scale, int64_t zero_point, ScalarType dtype) const = 0;
   virtual Tensor dequantize(const Tensor & self) const = 0;
+  virtual Tensor dequantize_linear(const Tensor & self, double scale, int64_t zero_point, ScalarType dtype) const = 0;
   virtual Scalar q_scale(const Tensor & self) const = 0;
   virtual Scalar q_zero_point(const Tensor & self) const = 0;
   virtual Tensor int_repr(const Tensor & self) const = 0;
@@ -2551,6 +2551,11 @@
   dispatch:
     QuantizedCPU: dequantize_quant
 
+- func: dequantize_linear(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor
+  variants: function, method
+  dispatch:
+    CPU: dequantize_linear_cpu
+
 - func: q_scale(Tensor self) -> Scalar
   variants: function, method
   dispatch:
@@ -16,6 +16,21 @@ Tensor dequantize_quant(const Tensor& self) {
   return get_qtensorimpl(self)->quantizer()->dequantize(self);
 }
 
+Tensor dequantize_linear_cpu(const Tensor& self, double scale, int64_t zero_point, ScalarType dtype) {
+  AT_CHECK(isQIntType(toQIntType(self.scalar_type())),
+           "Scalar type for quantized Tensor must have same underlying type as input.");
+  AT_CHECK(dtype == ScalarType::Float, "ScalarType for target Tensor must be float.");
+  Tensor f = at::empty(self.sizes(), self.options().dtype(dtype));
+  AT_DISPATCH_QINT_TYPES(
+      toQIntType(self.scalar_type()), "dequantize_linear_cpu", [&]() {
+        underlying_t* qdata = self.data<underlying_t>();
+        auto* fdata = f.data<float>();
+        for (int i = 0; i < self.numel(); ++i) {
+          fdata[i] = (static_cast<float>(qdata[i]) - zero_point) * scale;
+        }});
+  return f;
+}
+
 Scalar q_scale_quant(const Tensor& self) {
   auto quantizer = get_qtensorimpl(self)->quantizer();
   AT_ASSERT(quantizer->qscheme() == kPerTensorAffine);
@@ -234,6 +234,36 @@ static inline bool isQIntType(ScalarType t) {
   return t == ScalarType::QInt8 || t == ScalarType::QUInt8 || t == ScalarType::QInt32;
 }
 
+static inline ScalarType toQIntType(ScalarType t) {
+  switch (t) {
+    case ScalarType::Byte:
+      return ScalarType::QUInt8;
+    case ScalarType::Char:
+      return ScalarType::QInt8;
+    case ScalarType::Int:
+      return ScalarType::QInt32;
+    default:
+      return t;
+  }
+}
+
+static inline ScalarType toUnderlying(ScalarType t) {
+  switch (t) {
+    case ScalarType::QUInt8:
+      return ScalarType::Byte;
+    case ScalarType::QInt8:
+      return ScalarType::Char;
+    case ScalarType::QInt32:
+      return ScalarType::Int;
+    default:
+      return t;
+  }
+}
+
+static inline bool isUnderlying(ScalarType type, ScalarType qtype) {
+  return type == toUnderlying(qtype);
+}
+
 static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
   // This is generated according to NumPy's promote_types
   constexpr auto u1 = ScalarType::Byte;
@@ -209,6 +209,7 @@ view of a storage and defines numeric operations on it.
    .. automethod:: cumsum
    .. automethod:: data_ptr
    .. automethod:: dequantize
+   .. automethod:: dequantize_linear
    .. automethod:: det
    .. automethod:: dense_dim
    .. automethod:: detach
@@ -2804,6 +2804,12 @@ class _TestTorchMixin(object):
         rqr = qr.dequantize()
         self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / scale))
 
+    def test_qtensor_dequantize_linear(self):
+        t = torch.arange(-10, 10, dtype=torch.int8)
+        scale = 3
+        zero_point = 2
+        qt = torch.dequantize_linear(t, scale, zero_point, torch.float)
+
 
     @unittest.skipIf(torch.cuda.device_count() < 2, 'fewer than 2 GPUs detected')
     def test_device_guard(self):
@@ -3031,6 +3031,15 @@ det() -> Tensor
 See :func:`torch.det`
 """)
 
+add_docstr_all('dequantize_linear',
+               r"""
+dequantize_linear(int_tensor, scale, zero_point) -> Tensor
+
+Dequantize an int Tensor that represents the underlying quantized data
+using affine quantization scheme with given scale and zero_point.
+returns a float Tensor.
+""")
+
 add_docstr_all('where',
                r"""
 where(condition, y) -> Tensor
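For reference, the docstring's affine formula works out as follows with the scale = 3 and zero_point = 2 values from the new test (a worked example, not additional code in the PR):

    # f = (q - zero_point) * scale
    # q = -10  ->  (-10 - 2) * 3 = -36.0
    # q =   0  ->  (  0 - 2) * 3 =  -6.0
    # q =   9  ->  (  9 - 2) * 3 =  21.0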