mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44678
This is a prototype PR that introduces 4-bit qtensors. The new dtype added for this is c10::quint4x2. The underlying storage for this is still uint8_t, so we pack two 4-bit values into a byte while quantizing. This change uses most of the existing scaffolding for qtensor storage. We allocate storage based on the dtype before creating a new qtensor. It also adds a dispatch mechanism for this dtype, so we can use it to get the bitwidth, qmin and qmax info while quantizing and packing the qtensor (this will also be used when we add a 2-bit qtensor). Kernels that use this dtype should be aware of the packing format.
Test Plan: Locally tested
```
x = torch.ones((100, 100), dtype=torch.float)
qx_8bit = torch.quantize_per_tensor(x, scale=1.0, zero_point=2, dtype=torch.quint8)
qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=2, dtype=torch.quint4x2)
torch.save(x, "temp.p")
print('Size float (B):', os.path.getsize("temp.p"))
os.remove('temp.p')
torch.save(qx_8bit, "temp.p")
print('Size quantized 8bit(B):', os.path.getsize("temp.p"))
os.remove('temp.p')
torch.save(qx, "temp.p")
print('Size quantized 4bit(B):', os.path.getsize("temp.p"))
os.remove('temp.p')
```
Size float (B): 40760
Size quantized 8bit(B): 10808
Size quantized 4bit(B): 5816
Imported from OSS
Reviewed By: raghuramank100
Differential Revision: D23993134
fbshipit-source-id: 073bf262f9680416150ba78ed2d932032275946d
80 lines
3.1 KiB
C
80 lines
3.1 KiB
C
#ifndef THP_STORAGE_INC
|
|
#define THP_STORAGE_INC
|
|
|
|
/*
 * Name-generation helpers for the per-type storage bindings.
 *
 * Each macro forwards to a TH_CONCAT_* helper (defined outside this header)
 * together with the `Real` token, so the result depends on what `Real` is
 * defined as at the point of use.
 */

/* Dotted name string built from `torch.`, Real and `Storage`
 * (presumably e.g. "torch.FloatStorage" -- confirm against TH_CONCAT_STRING_3). */
#define THPStorageStr TH_CONCAT_STRING_3(torch.,Real,Storage)
/* Identifier built from THP, Real and StorageClass (THP<Real>StorageClass). */
#define THPStorageClass TH_CONCAT_3(THP,Real,StorageClass)
/* Identifier built from THP, Real, Storage_ and NAME (THP<Real>Storage_<NAME>). */
#define THPStorage_(NAME) TH_CONCAT_4(THP,Real,Storage_,NAME)
|
|
|
|
/*
 * Instance checks for the Double and Float storage classes.
 *
 * Each macro expands to PyObject_IsInstance(obj, THP<Type>StorageClass).
 * Per the CPython C API, PyObject_IsInstance returns 1 / 0 and -1 on error,
 * so callers that treat the result as a boolean should be aware that the
 * error value is also non-zero.
 */
#define THPDoubleStorage_Check(obj) \
    PyObject_IsInstance(obj, THPDoubleStorageClass)
#define THPFloatStorage_Check(obj) \
    PyObject_IsInstance(obj, THPFloatStorageClass)
|
|
/*
 * Instance check for the Half storage class.
 *
 * Expands to PyObject_IsInstance(obj, THPHalfStorageClass); the class tested
 * must be the Half one, not the Float one, so that a Float storage is not
 * reported as Half (and a Half storage is not rejected).
 * NOTE(review): relies on THPHalfStorageClass being declared by the
 * THGenerateHalfType.h pass over generic/Storage.h -- confirm.
 */
#define THPHalfStorage_Check(obj) \
    PyObject_IsInstance(obj, THPHalfStorageClass)
|
|
/*
 * Instance checks for the remaining integer, bool, quantized, bfloat16 and
 * complex storage classes.
 *
 * Every macro has the same shape: it expands to
 * PyObject_IsInstance(obj, THP<Type>StorageClass) for its own <Type>.
 */
#define THPLongStorage_Check(obj) \
    PyObject_IsInstance(obj, THPLongStorageClass)
#define THPIntStorage_Check(obj) \
    PyObject_IsInstance(obj, THPIntStorageClass)
#define THPShortStorage_Check(obj) \
    PyObject_IsInstance(obj, THPShortStorageClass)
#define THPCharStorage_Check(obj) \
    PyObject_IsInstance(obj, THPCharStorageClass)
#define THPByteStorage_Check(obj) \
    PyObject_IsInstance(obj, THPByteStorageClass)
#define THPBoolStorage_Check(obj) \
    PyObject_IsInstance(obj, THPBoolStorageClass)
#define THPQUInt8Storage_Check(obj) \
    PyObject_IsInstance(obj, THPQUInt8StorageClass)
#define THPQInt8Storage_Check(obj) \
    PyObject_IsInstance(obj, THPQInt8StorageClass)
#define THPQInt32Storage_Check(obj) \
    PyObject_IsInstance(obj, THPQInt32StorageClass)
#define THPBFloat16Storage_Check(obj) \
    PyObject_IsInstance(obj, THPBFloat16StorageClass)
#define THPComplexDoubleStorage_Check(obj) \
    PyObject_IsInstance(obj, THPComplexDoubleStorageClass)
#define THPComplexFloatStorage_Check(obj) \
    PyObject_IsInstance(obj, THPComplexFloatStorageClass)
|
|
/*
 * Instance check for the QUInt4x2 storage class.
 *
 * Expands to PyObject_IsInstance(obj, THPQUInt4x2StorageClass), so that only
 * QUInt4x2 storages pass; testing against the QUInt8 class would accept
 * QUInt8 storages and reject QUInt4x2 ones.
 * NOTE(review): assumes THPQUInt4x2StorageClass is declared by the
 * THGenerateQTypes.h pass over generic/Storage.h -- confirm before merging.
 */
#define THPQUInt4x2Storage_Check(obj) \
    PyObject_IsInstance(obj, THPQUInt4x2StorageClass)
|
|
|
|
/*
 * Accessors for the `cdata` member of a storage wrapper object.
 *
 * All variants expand to the same expression, (obj)->cdata; `obj` must be a
 * pointer to something with a `cdata` member. The result is an lvalue, so it
 * can be both read and assigned through.
 */
#define THPDoubleStorage_CData(obj) (obj)->cdata
#define THPFloatStorage_CData(obj) (obj)->cdata
#define THPHalfStorage_CData(obj) (obj)->cdata
#define THPLongStorage_CData(obj) (obj)->cdata
#define THPIntStorage_CData(obj) (obj)->cdata
#define THPShortStorage_CData(obj) (obj)->cdata
#define THPCharStorage_CData(obj) (obj)->cdata
#define THPByteStorage_CData(obj) (obj)->cdata
#define THPBoolStorage_CData(obj) (obj)->cdata
#define THPQUInt8Storage_CData(obj) (obj)->cdata
#define THPQInt8Storage_CData(obj) (obj)->cdata
#define THPQInt32Storage_CData(obj) (obj)->cdata
#define THPBFloat16Storage_CData(obj) (obj)->cdata
#define THPComplexDoubleStorage_CData(obj) (obj)->cdata
#define THPComplexFloatStorage_CData(obj) (obj)->cdata
#define THPQUInt4x2Storage_CData(obj) (obj)->cdata
|
|
|
|
/*
 * Further `Real`-parameterised names, defined just before the
 * generic/Storage.h includes that follow in this header.
 */

/* Identifier built from THP, Real and StorageType (THP<Real>StorageType). */
#define THPStorageType TH_CONCAT_3(THP,Real,StorageType)
/* String built from Real and StorageBase
 * (presumably e.g. "FloatStorageBase" -- confirm against TH_CONCAT_STRING_2). */
#define THPStorageBaseStr TH_CONCAT_STRING_2(Real,StorageBase)
|
|
|
|
#include <torch/csrc/generic/Storage.h>
|
|
#include <TH/THGenerateAllTypes.h>
|
|
|
|
#include <torch/csrc/generic/Storage.h>
|
|
#include <TH/THGenerateComplexTypes.h>
|
|
|
|
#include <torch/csrc/generic/Storage.h>
|
|
#include <TH/THGenerateHalfType.h>
|
|
|
|
#include <torch/csrc/generic/Storage.h>
|
|
#include <TH/THGenerateBoolType.h>
|
|
|
|
#include <torch/csrc/generic/Storage.h>
|
|
#include <TH/THGenerateBFloat16Type.h>
|
|
|
|
#include <torch/csrc/generic/Storage.h>
|
|
#include <TH/THGenerateQTypes.h>
|
|
|
|
#endif
|