Files
pytorch/torch/csrc/Storage.h
Supriya Rao 04526a49d3 [quant] creating quint4x2 dtype for quantized tensors (#44678)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44678

This is a prototype PR that introduces 4-bit qtensors. The new dtype added for this is c10::quint4x2.
The underlying storage is still uint8_t, so two 4-bit values are packed into each byte during quantization.

This change uses most of the existing scaffolding for qtensor storage. We allocate storage
based on the dtype before creating a new qtensor.

It also adds a dispatch mechanism for this dtype, which is used to look up the bitwidth, qmin and qmax
while quantizing and packing the qtensor (and can be reused when a 2-bit qtensor is added).

Kernels that use this dtype should be aware of the packing format.

Test Plan:
Locally tested
```
x = torch.ones((100, 100), dtype=torch.float)
qx_8bit = torch.quantize_per_tensor(x, scale=1.0, zero_point=2, dtype=torch.quint8)
qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=2, dtype=torch.quint4x2)

torch.save(x, "temp.p")
print('Size float (B):', os.path.getsize("temp.p"))
os.remove('temp.p')

torch.save(qx_8bit, "temp.p")
print('Size quantized 8bit(B):', os.path.getsize("temp.p"))
os.remove('temp.p')

torch.save(qx, "temp.p")
print('Size quantized 4bit(B):', os.path.getsize("temp.p"))
os.remove('temp.p')
```

Size float (B): 40760
Size quantized 8bit(B): 10808
Size quantized 4bit(B): 5816

Imported from OSS

Reviewed By: raghuramank100

Differential Revision: D23993134

fbshipit-source-id: 073bf262f9680416150ba78ed2d932032275946d
2020-10-01 23:53:34 -07:00

80 lines
3.1 KiB
C

#ifndef THP_STORAGE_INC
#define THP_STORAGE_INC
/*
 * Per-scalar-type naming helpers. `Real` is not defined in this header;
 * NOTE(review): presumably it is bound by the TH generic-file machinery
 * (THGenerate*.h) while these names are expanded — confirm.
 *
 *   THPStorage_(NAME) -> THP<Real>Storage_NAME
 *   THPStorageClass   -> THP<Real>StorageClass
 *   THPStorageStr     -> "torch.<Real>Storage"
 */
#define THPStorage_(NAME) \
  TH_CONCAT_4(THP, Real, Storage_, NAME)
#define THPStorageClass \
  TH_CONCAT_3(THP, Real, StorageClass)
#define THPStorageStr \
  TH_CONCAT_STRING_3(torch., Real, Storage)
/*
 * THP<Type>Storage_Check(obj) expands to
 *   PyObject_IsInstance((obj), THP<Type>StorageClass)
 * i.e. a runtime test of whether obj is an instance of the Python class held
 * in THP<Type>StorageClass.
 *
 * Every check is generated through THP_STORAGE_CHECK_ so the class name is
 * pasted from the same <Type> token as the macro name. Writing the class name
 * by hand for each macro is what let the Half check test the Float class and
 * the QUInt4x2 check test the QUInt8 class.
 *
 * NOTE: PyObject_IsInstance returns 1 (instance), 0 (not an instance) or -1
 * (error, Python exception set); callers should not treat -1 as "true".
 */
#define THP_STORAGE_CHECK_(obj, Type) \
  PyObject_IsInstance((obj), THP##Type##StorageClass)

#define THPDoubleStorage_Check(obj) THP_STORAGE_CHECK_(obj, Double)
#define THPFloatStorage_Check(obj) THP_STORAGE_CHECK_(obj, Float)
#define THPHalfStorage_Check(obj) THP_STORAGE_CHECK_(obj, Half)
#define THPLongStorage_Check(obj) THP_STORAGE_CHECK_(obj, Long)
#define THPIntStorage_Check(obj) THP_STORAGE_CHECK_(obj, Int)
#define THPShortStorage_Check(obj) THP_STORAGE_CHECK_(obj, Short)
#define THPCharStorage_Check(obj) THP_STORAGE_CHECK_(obj, Char)
#define THPByteStorage_Check(obj) THP_STORAGE_CHECK_(obj, Byte)
#define THPBoolStorage_Check(obj) THP_STORAGE_CHECK_(obj, Bool)
#define THPQUInt8Storage_Check(obj) THP_STORAGE_CHECK_(obj, QUInt8)
#define THPQInt8Storage_Check(obj) THP_STORAGE_CHECK_(obj, QInt8)
#define THPQInt32Storage_Check(obj) THP_STORAGE_CHECK_(obj, QInt32)
#define THPBFloat16Storage_Check(obj) THP_STORAGE_CHECK_(obj, BFloat16)
#define THPComplexDoubleStorage_Check(obj) THP_STORAGE_CHECK_(obj, ComplexDouble)
#define THPComplexFloatStorage_Check(obj) THP_STORAGE_CHECK_(obj, ComplexFloat)
#define THPQUInt4x2Storage_Check(obj) THP_STORAGE_CHECK_(obj, QUInt4x2)
/*
 * THP<Type>Storage_CData(obj) names the `cdata` member of the object that the
 * pointer obj refers to. The expansion is an lvalue, so it can be read,
 * assigned, or have its address taken. All scalar types share one expansion.
 */
#define THP_STORAGE_CDATA_(obj) (obj)->cdata

#define THPDoubleStorage_CData(obj)        THP_STORAGE_CDATA_(obj)
#define THPFloatStorage_CData(obj)         THP_STORAGE_CDATA_(obj)
#define THPHalfStorage_CData(obj)          THP_STORAGE_CDATA_(obj)
#define THPLongStorage_CData(obj)          THP_STORAGE_CDATA_(obj)
#define THPIntStorage_CData(obj)           THP_STORAGE_CDATA_(obj)
#define THPShortStorage_CData(obj)         THP_STORAGE_CDATA_(obj)
#define THPCharStorage_CData(obj)          THP_STORAGE_CDATA_(obj)
#define THPByteStorage_CData(obj)          THP_STORAGE_CDATA_(obj)
#define THPBoolStorage_CData(obj)          THP_STORAGE_CDATA_(obj)
#define THPQUInt8Storage_CData(obj)        THP_STORAGE_CDATA_(obj)
#define THPQInt8Storage_CData(obj)         THP_STORAGE_CDATA_(obj)
#define THPQInt32Storage_CData(obj)        THP_STORAGE_CDATA_(obj)
#define THPBFloat16Storage_CData(obj)      THP_STORAGE_CDATA_(obj)
#define THPComplexDoubleStorage_CData(obj) THP_STORAGE_CDATA_(obj)
#define THPComplexFloatStorage_CData(obj)  THP_STORAGE_CDATA_(obj)
#define THPQUInt4x2Storage_CData(obj)      THP_STORAGE_CDATA_(obj)
/*
 * Further per-`Real` naming helpers:
 *   THPStorageBaseStr -> "<Real>StorageBase"
 *   THPStorageType    -> THP<Real>StorageType
 * NOTE(review): `Real` is presumably bound by the TH generic-file passes
 * when these are expanded — confirm.
 */
#define THPStorageBaseStr \
  TH_CONCAT_STRING_2(Real, StorageBase)
#define THPStorageType \
  TH_CONCAT_3(THP, Real, StorageType)
/*
 * Instantiate torch/csrc/generic/Storage.h for each scalar-type family.
 * Each pair includes the generic header and then a TH/THGenerate*.h driver;
 * keep each pair together and in this order.
 * NOTE(review): presumably the first include registers the header as
 * TH_GENERIC_FILE and the driver re-includes it once per type with `Real`
 * bound (which is what the THPStorage_/THPStorageClass/... helpers above
 * rely on) — confirm against generic/Storage.h and THGenerate*.h.
 */
#include <torch/csrc/generic/Storage.h>
#include <TH/THGenerateAllTypes.h>
/* THGenerateComplexTypes.h pass. */
#include <torch/csrc/generic/Storage.h>
#include <TH/THGenerateComplexTypes.h>
/* THGenerateHalfType.h pass. */
#include <torch/csrc/generic/Storage.h>
#include <TH/THGenerateHalfType.h>
/* THGenerateBoolType.h pass. */
#include <torch/csrc/generic/Storage.h>
#include <TH/THGenerateBoolType.h>
/* THGenerateBFloat16Type.h pass. */
#include <torch/csrc/generic/Storage.h>
#include <TH/THGenerateBFloat16Type.h>
/*
 * THGenerateQTypes.h pass. Presumably covers the quantized types whose
 * _Check/_CData macros are declared above (QUInt8, QInt8, QInt32,
 * QUInt4x2) — verify.
 */
#include <torch/csrc/generic/Storage.h>
#include <TH/THGenerateQTypes.h>
#endif