mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add as_quantized_tensor (#20740)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20740 Provide a way to assemble quantized Tensor from int8 Tensor, scale and zero point. Differential Revision: D15232416 fbshipit-source-id: c3a3d9d7214b1dc569214c019440c2779fbd063b
This commit is contained in:
committed by
Facebook Github Bot
parent
12bc81ae2a
commit
9ea009fe8b
@ -2573,6 +2573,10 @@
|
||||
dispatch:
|
||||
QuantizedCPU: int_repr_quant
|
||||
|
||||
- func: _per_tensor_affine_qtensor(Tensor self, float scale, int zero_point) -> Tensor
|
||||
dispatch:
|
||||
CPU: per_tensor_affine_qtensor_cpu
|
||||
|
||||
# to(Device) must not exist because all constructors of Device also works for
|
||||
# TensorOptions. Otherwise, an ambiguity error is thrown.
|
||||
# See NOTE [ TensorOptions Constructors ].
|
||||
|
@ -60,5 +60,17 @@ Tensor int_repr_quant(const Tensor& self) {
|
||||
return dst;
|
||||
}
|
||||
|
||||
// Assemble a per-tensor-affine quantized CPU tensor from a tensor of raw
// integer values plus a scale and zero point.
//
// The result is a freshly allocated quantized tensor (no storage sharing
// with `self`); the raw values of `self` are copied in verbatim.
//
// @param self        tensor holding the raw integer representation; its
//                    scalar type is mapped through toQIntType, so it must be
//                    the underlying type of the target quantized dtype
//                    (e.g. uint8 for quint8) — TODO confirm callers enforce this
// @param scale       affine quantization scale of the result
// @param zero_point  affine quantization zero point of the result
// @return a new quantized tensor with the same sizes as `self`
Tensor per_tensor_affine_qtensor_cpu(const Tensor& self, double scale, int64_t zero_point) {
  Tensor dst = at::_empty_affine_quantized(self.sizes(), self.options().dtype(toQIntType(self.scalar_type())), scale, zero_point);
  AT_DISPATCH_QINT_TYPES(dst.scalar_type(), "per_tensor_affine_qtensor", [&]() {
    underlying_t* self_data = self.data<underlying_t>();
    underlying_t* dst_data = reinterpret_cast<underlying_t *>(dst.data<scalar_t>());
    if (self.numel() > 0) {
      // memcpy takes a BYTE count, not an element count. numel() alone is
      // only correct for 1-byte underlying types (qint8/quint8); for qint32
      // it would copy just a quarter of the data. Scale by the element size.
      memcpy(dst_data, self_data, self.numel() * sizeof(underlying_t));
    }
  });
  return dst;
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
@ -2798,8 +2798,16 @@ class _TestTorchMixin(object):
|
||||
zero_point = 10
|
||||
val = 100
|
||||
numel = 10
|
||||
q = torch._empty_affine_quantized(numel, dtype=torch.quint8, scale=scale, zero_point=zero_point)
|
||||
# TODO: check dequantized values?
|
||||
q = torch._empty_affine_quantized([numel], scale=scale, zero_point=zero_point, dtype=torch.quint8)
|
||||
self.assertEqual(scale, q.q_scale())
|
||||
self.assertEqual(zero_point, q.q_zero_point())
|
||||
|
||||
# create Tensor from uint8_t Tensor, scale and zero_point
|
||||
int_tensor = torch.randint(0, 100, size=(10,), dtype=torch.uint8)
|
||||
q = torch._per_tensor_affine_qtensor(int_tensor, scale, zero_point)
|
||||
self.assertEqual(int_tensor, q.int_repr())
|
||||
self.assertEqual(scale, q.q_scale())
|
||||
self.assertEqual(zero_point, q.q_zero_point())
|
||||
|
||||
def test_qtensor_dtypes(self):
|
||||
r = np.random.rand(3, 2) * 2 - 4
|
||||
|
Reference in New Issue
Block a user