mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44678

This is a prototype PR that introduces 4-bit qtensors. The new dtype added for this is c10::quint4x2.
The underlying storage for this is still uint8_t, so we pack two 4-bit values into a byte while quantizing.

This change uses most of the existing scaffolding for qtensor storage. We allocate storage based on the dtype before creating a new qtensor.

It also adds a dispatch mechanism for this dtype, so we can use it to get the bitwidth, qmin and qmax info while quantizing and packing the qtensor (when we add 2-bit qtensor).

Kernels that use this dtype should be aware of the packing format.

Test Plan:
Locally tested
```
x = torch.ones((100, 100), dtype=torch.float)
qx_8bit = torch.quantize_per_tensor(x, scale=1.0, zero_point=2, dtype=torch.quint8)
qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=2, dtype=torch.quint4x2)

torch.save(x, "temp.p")
print('Size float (B):', os.path.getsize("temp.p"))
os.remove('temp.p')

torch.save(qx_8bit, "temp.p")
print('Size quantized 8bit(B):', os.path.getsize("temp.p"))
os.remove('temp.p')

torch.save(qx, "temp.p")
print('Size quantized 4bit(B):', os.path.getsize("temp.p"))
os.remove('temp.p')
```

Size float (B): 40760
Size quantized 8bit(B): 10808
Size quantized 4bit(B): 5816

Imported from OSS

Reviewed By: raghuramank100

Differential Revision: D23993134

fbshipit-source-id: 073bf262f9680416150ba78ed2d932032275946d
199 lines
9.8 KiB
C++
199 lines
9.8 KiB
C++
#ifndef THP_UTILS_H
|
|
#define THP_UTILS_H
|
|
|
|
#include <cmath>
#include <string>
#include <type_traits>
#include <vector>

#include <ATen/ATen.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_compat.h>
|
|
|
|
#ifdef USE_CUDA
|
|
#include <THC/THC.h>
|
|
#include <c10/cuda/CUDAStream.h>
|
|
#endif
|
|
|
|
// Builds an identifier from THP, Real, Utils_ and NAME through TH_CONCAT_4.
// Real is not defined in this header; presumably it is supplied by the TH
// generic-type headers included further down -- TODO confirm.
#define THPUtils_(NAME) TH_CONCAT_4(THP,Real,Utils_,NAME)
|
|
|
|
// Expands to the tp_name field of the Python type object of obj.
#define THPUtils_typename(obj) (Py_TYPE(obj)->tp_name)
|
|
|
|
// THP_EXPECT(x, y): on GCC, Intel (__ICL) and Clang this forwards x and y to
// __builtin_expect; on any other compiler it expands to (x) and y is unused.
#if defined(__GNUC__) || defined(__ICL) || defined(__clang__)
|
|
#define THP_EXPECT(x, y) (__builtin_expect((x), (y)))
|
|
#else
|
|
#define THP_EXPECT(x, y) (x)
|
|
#endif
|
|
|
|
// True when object passes PyFloat_Check or PyLong_Check, i.e. Python ints are
// accepted wherever a floating-point value is expected.
#define THPUtils_checkReal_FLOAT(object) \
|
|
(PyFloat_Check(object) || PyLong_Check(object))
|
|
|
|
// Converts a Python float (PyFloat_AsDouble) or Python int (PyLong_AsLongLong)
// to a C++ number; throws std::runtime_error("Could not parse real") for any
// other object. object is evaluated more than once.
#define THPUtils_unpackReal_FLOAT(object) \
|
|
(PyFloat_Check(object) ? PyFloat_AsDouble(object) : \
|
|
PyLong_Check(object) ? PyLong_AsLongLong(object) : \
|
|
(throw std::runtime_error("Could not parse real"), 0))
|
|
|
|
// True when object passes PyLong_Check.
#define THPUtils_checkReal_INT(object) \
|
|
PyLong_Check(object)
|
|
|
|
// Converts a Python int with PyLong_AsLongLong; throws
// std::runtime_error("Could not parse real") for any other object.
// object is evaluated more than once.
#define THPUtils_unpackReal_INT(object) \
|
|
(PyLong_Check(object) ? PyLong_AsLongLong(object) : \
|
|
(throw std::runtime_error("Could not parse real"), 0))
|
|
|
|
// Converts a Python bool to a C++ bool; throws
// std::runtime_error("Could not parse real") for any other object.
//
// The value is obtained by comparing against Py_True instead of handing back
// the PyObject* itself: Py_True and Py_False are both non-null pointers, so a
// pointer result converted to bool or int64_t (as the THPBoolUtils_* macros in
// this header do) would turn Python False into a truthy value.
// object is evaluated more than once.
#define THPUtils_unpackReal_BOOL(object) \
  (PyBool_Check(object) ? ((object) == Py_True) : \
  (throw std::runtime_error("Could not parse real"), false))
|
|
|
|
// Converts a Python complex, float or int to c10::complex<double>; floats and
// ints become the real part with a zero imaginary part. Throws
// std::runtime_error("Could not parse real") for any other object.
// object is evaluated more than once.
//
// The final line deliberately carries no trailing backslash, so the macro ends
// here and nothing written after it can be spliced into the expansion.
#define THPUtils_unpackReal_COMPLEX(object) \
  (PyComplex_Check(object) ? \
  (c10::complex<double>(PyComplex_RealAsDouble(object), PyComplex_ImagAsDouble(object))) : \
  PyFloat_Check(object) ? (c10::complex<double>(PyFloat_AsDouble(object), 0)) : \
  PyLong_Check(object) ? (c10::complex<double>(PyLong_AsLongLong(object), 0)) : \
  (throw std::runtime_error("Could not parse real"), c10::complex<double>(0,0)))
|
|
|
|
// True when object passes PyBool_Check.
#define THPUtils_checkReal_BOOL(object) \
|
|
PyBool_Check(object)
|
|
|
|
// True when object is a Python complex, float or int.
// The whole disjunction is parenthesized so the macro behaves as a single
// operand: without the outer parentheses `!THPUtils_checkReal_COMPLEX(x)` or
// `cond && THPUtils_checkReal_COMPLEX(x)` would bind to the first check only.
#define THPUtils_checkReal_COMPLEX(object) \
  (PyComplex_Check(object) || PyFloat_Check(object) || PyLong_Check(object) || PyInt_Check(object))
|
|
|
|
// Constructors for Python objects from C++ scalars.
// FLOAT forwards to PyFloat_FromDouble.
#define THPUtils_newReal_FLOAT(value) PyFloat_FromDouble(value)
|
|
// INT forwards to PyInt_FromLong, which is not declared in this header.
// NOTE(review): presumably it comes from python_compat.h and takes a C long,
// so 64-bit values could be narrowed where long is 32 bits -- confirm.
#define THPUtils_newReal_INT(value) PyInt_FromLong(value)
|
|
|
|
// BOOL forwards to PyBool_FromLong.
#define THPUtils_newReal_BOOL(value) PyBool_FromLong(value)
|
|
|
|
// Builds a Python complex from a value exposing real() and imag().
// value is parenthesized before the member access so that an expression
// argument such as `a + b` is evaluated as a whole; it is evaluated twice.
#define THPUtils_newReal_COMPLEX(value) PyComplex_FromDoubles((value).real(), (value).imag())
|
|
|
|
// Per-type aliases for the floating-point families. Each TYPE gets
// checkReal / unpackReal / newReal built on the *_FLOAT macros; unpackReal
// casts the unpacked value to the C++ type named in the macro.
// Double -> double
#define THPDoubleUtils_checkReal(object) THPUtils_checkReal_FLOAT(object)
|
|
#define THPDoubleUtils_unpackReal(object) (double)THPUtils_unpackReal_FLOAT(object)
|
|
#define THPDoubleUtils_newReal(value) THPUtils_newReal_FLOAT(value)
|
|
// Float -> float
#define THPFloatUtils_checkReal(object) THPUtils_checkReal_FLOAT(object)
|
|
#define THPFloatUtils_unpackReal(object) (float)THPUtils_unpackReal_FLOAT(object)
|
|
#define THPFloatUtils_newReal(value) THPUtils_newReal_FLOAT(value)
|
|
// Half -> at::Half; newReal calls PyFloat_FromDouble directly, which is what
// THPUtils_newReal_FLOAT expands to as well. Half also defines newAccreal.
#define THPHalfUtils_checkReal(object) THPUtils_checkReal_FLOAT(object)
|
|
#define THPHalfUtils_unpackReal(object) (at::Half)THPUtils_unpackReal_FLOAT(object)
|
|
#define THPHalfUtils_newReal(value) PyFloat_FromDouble(value)
|
|
#define THPHalfUtils_newAccreal(value) THPUtils_newReal_FLOAT(value)
|
|
// Per-type aliases for the complex families, built on the *_COMPLEX macros.
// ComplexDouble uses the c10::complex<double> result as-is.
#define THPComplexDoubleUtils_checkReal(object) THPUtils_checkReal_COMPLEX(object)
|
|
#define THPComplexDoubleUtils_unpackReal(object) THPUtils_unpackReal_COMPLEX(object)
|
|
#define THPComplexDoubleUtils_newReal(value) THPUtils_newReal_COMPLEX(value)
|
|
// ComplexFloat casts the unpacked value to c10::complex<float>.
#define THPComplexFloatUtils_checkReal(object) THPUtils_checkReal_COMPLEX(object)
|
|
#define THPComplexFloatUtils_unpackReal(object) (c10::complex<float>)THPUtils_unpackReal_COMPLEX(object)
|
|
#define THPComplexFloatUtils_newReal(value) THPUtils_newReal_COMPLEX(value)
|
|
// BFloat16 aliases: checked and unpacked through the *_FLOAT macros, with the
// unpacked value cast to at::BFloat16; newReal calls PyFloat_FromDouble
// directly and newAccreal goes through THPUtils_newReal_FLOAT.
#define THPBFloat16Utils_checkReal(object) THPUtils_checkReal_FLOAT(object)
|
|
#define THPBFloat16Utils_unpackReal(object) (at::BFloat16)THPUtils_unpackReal_FLOAT(object)
|
|
#define THPBFloat16Utils_newReal(value) PyFloat_FromDouble(value)
|
|
#define THPBFloat16Utils_newAccreal(value) THPUtils_newReal_FLOAT(value)
|
|
|
|
// Bool aliases, built on the *_BOOL macros. The Real and Accreal variants
// share the same check and constructor; unpackAccreal additionally casts the
// unpacked value to int64_t.
#define THPBoolUtils_checkReal(object) THPUtils_checkReal_BOOL(object)
|
|
#define THPBoolUtils_unpackReal(object) THPUtils_unpackReal_BOOL(object)
|
|
#define THPBoolUtils_newReal(value) THPUtils_newReal_BOOL(value)
|
|
#define THPBoolUtils_checkAccreal(object) THPUtils_checkReal_BOOL(object)
|
|
#define THPBoolUtils_unpackAccreal(object) (int64_t)THPUtils_unpackReal_BOOL(object)
|
|
#define THPBoolUtils_newAccreal(value) THPUtils_newReal_BOOL(value)
|
|
// Per-type aliases for the integer families, built on the *_INT macros.
// unpackReal casts the long long produced by THPUtils_unpackReal_INT to the
// C++ type named in each macro; no range check is performed by these macros.
// Long -> int64_t
#define THPLongUtils_checkReal(object) THPUtils_checkReal_INT(object)
|
|
#define THPLongUtils_unpackReal(object) (int64_t)THPUtils_unpackReal_INT(object)
|
|
#define THPLongUtils_newReal(value) THPUtils_newReal_INT(value)
|
|
// Int -> int
#define THPIntUtils_checkReal(object) THPUtils_checkReal_INT(object)
|
|
#define THPIntUtils_unpackReal(object) (int)THPUtils_unpackReal_INT(object)
|
|
#define THPIntUtils_newReal(value) THPUtils_newReal_INT(value)
|
|
// Short -> short
#define THPShortUtils_checkReal(object) THPUtils_checkReal_INT(object)
|
|
#define THPShortUtils_unpackReal(object) (short)THPUtils_unpackReal_INT(object)
|
|
#define THPShortUtils_newReal(value) THPUtils_newReal_INT(value)
|
|
// Char -> char
#define THPCharUtils_checkReal(object) THPUtils_checkReal_INT(object)
|
|
#define THPCharUtils_unpackReal(object) (char)THPUtils_unpackReal_INT(object)
|
|
#define THPCharUtils_newReal(value) THPUtils_newReal_INT(value)
|
|
// Byte -> unsigned char
#define THPByteUtils_checkReal(object) THPUtils_checkReal_INT(object)
|
|
#define THPByteUtils_unpackReal(object) (unsigned char)THPUtils_unpackReal_INT(object)
|
|
#define THPByteUtils_newReal(value) THPUtils_newReal_INT(value)
|
|
// Quantized types (QUInt8, QInt8, QInt32, QUInt4x2). All four reuse the *_INT
// macros and unpack to a plain int.
|
|
#define THPQUInt8Utils_checkReal(object) THPUtils_checkReal_INT(object)
|
|
#define THPQUInt8Utils_unpackReal(object) (int)THPUtils_unpackReal_INT(object)
|
|
#define THPQUInt8Utils_newReal(value) THPUtils_newReal_INT(value)
|
|
#define THPQInt8Utils_checkReal(object) THPUtils_checkReal_INT(object)
|
|
#define THPQInt8Utils_unpackReal(object) (int)THPUtils_unpackReal_INT(object)
|
|
#define THPQInt8Utils_newReal(value) THPUtils_newReal_INT(value)
|
|
#define THPQInt32Utils_checkReal(object) THPUtils_checkReal_INT(object)
|
|
#define THPQInt32Utils_unpackReal(object) (int)THPUtils_unpackReal_INT(object)
|
|
#define THPQInt32Utils_newReal(value) THPUtils_newReal_INT(value)
|
|
#define THPQUInt4x2Utils_checkReal(object) THPUtils_checkReal_INT(object)
|
|
#define THPQUInt4x2Utils_unpackReal(object) (int)THPUtils_unpackReal_INT(object)
|
|
#define THPQUInt4x2Utils_newReal(value) THPUtils_newReal_INT(value)
|
|
|
|
|
|
// THPUtils_assert(cond, ...): THPUtils_assertRet with nullptr as the value
// returned on failure.
#define THPUtils_assert(cond, ...) THPUtils_assertRet(nullptr, cond, __VA_ARGS__)
|
|
// THPUtils_assertRet(value, cond, ...): when cond is false, passes the
// remaining arguments to THPUtils_setError and returns value from the
// enclosing function. The failure branch is marked unlikely via THP_EXPECT.
// The expansion is a bare `if` statement (not wrapped in do/while), so it
// should not be used as the unbraced body of an if that has an else.
#define THPUtils_assertRet(value, cond, ...) \
|
|
if (THP_EXPECT(!(cond), 0)) { THPUtils_setError(__VA_ARGS__); return value; }
|
|
// Variadic error setter used by THPUtils_assertRet. Declaration only; the
// definition is not in this header. Presumably format is a printf-style
// string consumed together with the variadic arguments -- TODO confirm.
THP_API void THPUtils_setError(const char *format, ...);
|
|
// Declaration only; the definition is not in this header.
// Takes the args/kwargs that were passed, the function name, and num_options
// followed by variadic arguments. Presumably the variadic arguments are
// num_options descriptions of accepted signatures -- TODO confirm.
THP_API void THPUtils_invalidArguments(
|
|
PyObject *given_args, PyObject *given_kwargs,
|
|
const char *function_name, size_t num_options, ...);
|
|
|
|
// Declaration only. Presumably reports whether arg is a tuple of Python
// ints -- confirm against the definition.
bool THPUtils_checkIntTuple(PyObject *arg);
|
|
// Declaration only. Returns a std::vector<int> built from arg; presumably arg
// must satisfy THPUtils_checkIntTuple -- confirm against the definition.
std::vector<int> THPUtils_unpackIntTuple(PyObject *arg);
|
|
|
|
// Declaration only. Takes the destination vector by non-const reference and a
// pointer to PyMethodDef entries; presumably appends the entries of methods
// to vector -- confirm against the definition.
void THPUtils_addPyMethodDefs(std::vector<PyMethodDef>& vector, PyMethodDef* methods);
|
|
|
|
// Declaration only. result is an out-parameter; the meaning of the int return
// value is not visible here -- NOTE(review): confirm against the definition.
int THPUtils_getCallable(PyObject *arg, PyObject **result);
|
|
|
|
// Per-type pointer names assembled with TH_CONCAT_3 from a prefix (TH, THP or
// THSP), Real, and a StoragePtr/TensorPtr suffix. Real is not defined in this
// header; presumably these names are only meaningful while one of the TH
// generic-type headers has Real defined -- TODO confirm.
#define THWStoragePtr TH_CONCAT_3(TH,Real,StoragePtr)
|
|
#define THWTensorPtr TH_CONCAT_3(TH,Real,TensorPtr)
|
|
#define THPStoragePtr TH_CONCAT_3(THP,Real,StoragePtr)
|
|
#define THPTensorPtr TH_CONCAT_3(THP,Real,TensorPtr)
|
|
#define THSPTensorPtr TH_CONCAT_3(THSP,Real,TensorPtr)
|
|
|
|
// Alias for a THPPointer holding a THPGenerator.
using THPGeneratorPtr = THPPointer<THPGenerator>;
|
|
|
|
// Primary template with no members. Presumably specialized per scalar type by
// the generic utils.h includes that follow -- TODO confirm.
template <typename T>
|
|
struct THPUtils_typeTraits {};
|
|
|
|
#include <torch/csrc/generic/utils.h>
|
|
#include <TH/THGenerateAllTypes.h>
|
|
|
|
#include <torch/csrc/generic/utils.h>
|
|
#include <TH/THGenerateComplexTypes.h>
|
|
|
|
#include <torch/csrc/generic/utils.h>
|
|
#include <TH/THGenerateHalfType.h>
|
|
|
|
#include <torch/csrc/generic/utils.h>
|
|
#include <TH/THGenerateBFloat16Type.h>
|
|
|
|
#include <torch/csrc/generic/utils.h>
|
|
#include <TH/THGenerateBoolType.h>
|
|
|
|
#include <torch/csrc/generic/utils.h>
|
|
#include <TH/THGenerateQTypes.h>
|
|
|
|
// Declaration only. Returns a THLongStoragePtr built from arg; presumably arg
// is a Python size/sequence of ints -- confirm against the definition.
THLongStoragePtr THPUtils_unpackSize(PyObject *arg);
|
|
// Declaration only. result is an out-parameter; presumably the bool return
// reports whether arg could be unpacked -- confirm against the definition.
bool THPUtils_tryUnpackLongs(PyObject *arg, THLongStoragePtr& result);
|
|
// Declaration only. Returns the values unpacked from arg as a
// std::vector<int64_t>; failure behavior is not visible here.
std::vector<int64_t> THPUtils_unpackLongs(PyObject *arg);
|
|
// Declaration only. result is an out-parameter; presumably ignore_first is the
// number of leading entries of args to skip -- confirm against the definition.
bool THPUtils_tryUnpackLongVarArgs(PyObject *args, int ignore_first, THLongStoragePtr& result);
|
|
// Declaration only. Takes a tensor object, a method name and Python
// args/kwargs and returns a PyObject*; presumably it looks up and calls the
// named stateless method for tensor -- confirm against the definition.
PyObject * THPUtils_dispatchStateless(PyObject *tensor, const char *name, PyObject *args, PyObject *kwargs);
|
|
|
|
// mod_traits<T>::mod(a, b) computes the remainder of a / b for arithmetic T.
//
// Primary template: intentionally empty, so mod_traits<T>::mod exists only for
// the type categories specialized below.
template <typename _real, typename = void>
struct mod_traits {};

// Floating-point types: remainder via std::fmod. The std:: qualification picks
// the overload matching _real (float, double or long double) rather than
// depending on whether a C ::fmod(double, double) happens to be visible.
template <typename _real>
struct mod_traits<_real, typename std::enable_if<std::is_floating_point<_real>::value>::type> {
  static _real mod(_real a, _real b) { return std::fmod(a, b); }
};

// Integral types: remainder via the built-in % operator. As with any use of %,
// b must not be zero.
template <typename _real>
struct mod_traits<_real, typename std::enable_if<std::is_integral<_real>::value>::type> {
  static _real mod(_real a, _real b) { return a % b; }
};
|
|
|
|
// Declaration only. Presumably stores the flag later returned by
// getBackCompatBroadcastWarn -- confirm against the definition.
void setBackCompatBroadcastWarn(bool warn);
|
|
// Declaration only. Presumably returns the flag set through
// setBackCompatBroadcastWarn -- confirm against the definition.
bool getBackCompatBroadcastWarn();
|
|
|
|
// Declaration only. Presumably stores the flag later returned by
// getBackCompatKeepdimWarn -- confirm against the definition.
void setBackCompatKeepdimWarn(bool warn);
|
|
// Declaration only. Presumably returns the flag set through
// setBackCompatKeepdimWarn -- confirm against the definition.
bool getBackCompatKeepdimWarn();
|
|
// Declaration only. func is taken as a non-const char*; presumably it names
// the calling function for the warning text and is not modified -- confirm
// against the definition. The meaning of the bool return is not visible here.
bool maybeThrowBackCompatKeepdimWarn(char *func);
|
|
|
|
// NB: This is in torch/csrc/cuda/utils.cpp, for whatever reason
|
|
#ifdef USE_CUDA
|
|
// Declaration only (CUDA builds). Returns one optional CUDAStream per result
// entry; presumably one entry per element of the Python sequence obj, with an
// empty optional for elements that carry no stream -- confirm against the
// definition.
std::vector<c10::optional<at::cuda::CUDAStream>> THPUtils_PySequence_to_CUDAStreamList(PyObject *obj);
|
|
#endif
|
|
|
|
#endif
|