Add unsigned integer dtypes to PyTorch (#116594)

The dtypes are barely usable right now (not even fill works), but this change makes torch.uint16, torch.uint32 and torch.uint64 available as dtypes.

Towards https://github.com/pytorch/pytorch/issues/58734

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116594
Approved by: https://github.com/albanD
ghstack dependencies: #116698, #116693
This commit is contained in:
Edward Z. Yang
2024-01-05 06:26:01 -05:00
committed by PyTorch MergeBot
parent 8e273e23b5
commit 8ddac14a15
5 changed files with 56 additions and 8 deletions

View File

@ -10,6 +10,9 @@ DLDataType getDLDataType(const Tensor& t) {
dtype.bits = t.element_size() * 8;
switch (t.scalar_type()) {
case ScalarType::Byte:
case ScalarType::UInt16:
case ScalarType::UInt32:
case ScalarType::UInt64:
dtype.code = DLDataTypeCode::kDLUInt;
break;
case ScalarType::Char:

View File

@ -74,6 +74,10 @@ ScalarType promoteTypes(ScalarType a, ScalarType b) {
toString(b));
}
if (isBarebonesUnsignedType(a) || isBarebonesUnsignedType(b)) {
return ScalarType::Undefined;
}
auto ix_a = dtype2index[static_cast<int64_t>(a)];
TORCH_INTERNAL_ASSERT(ix_a != -1);
auto ix_b = dtype2index[static_cast<int64_t>(b)];

View File

@ -26,11 +26,24 @@
namespace c10 {
// For the macros below:
// NB: If you want to macro some code for all non-QInt scalar types (i.e. types
// with complete information, you probably want one of the
// AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND
// macros below, which are designed to behave similarly to the Dispatch macros
// with the same name.
//
// For users: If you want to macro some code for all non-QInt scalar types
// (i.e. types with complete information), you probably want one of the
// AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND macros below, which are
// designed to behave similarly to the Dispatch macros with the same name.
//
// For adding a new dtype: In the beginning, we had an idea that there was a
// list of all scalar types, and you could use AT_FORALL_SCALAR_TYPES to
// iterate over them. But over the years we added weird types which couldn't
// be handled uniformly everywhere, so we ended up with a mishmash of helper
// macros, with most use sites deciding for themselves which dtypes they can
// or can't support. So if you want to add a new dtype,
// the preferred resolution is to find a dtype similar to what you want,
// grep for it and edit all the sites you find this way. If you need to add
// a completely new kind of dtype, you're going to have to laboriously audit
// all of the sites everywhere to figure out how it should work. Consulting
// some old PRs where we added new dtypes (check history of this file) can
// help give you an idea where to start.
// NB: Order matters for this macro; it is relied upon in
// _promoteTypesLookup and the serialization format.
@ -61,11 +74,18 @@ namespace c10 {
_(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \
_(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \
_(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \
_(uint16_t, UInt16) /* 27 */ \
_(uint32_t, UInt32) /* 28 */ \
_(uint64_t, UInt64) /* 29 */
// If you want to support ComplexHalf for real, add ComplexHalf
// into this macro (and change the name). But beware: convert()
// doesn't work for all the conversions you need...
//
// TODO: To add unsigned int types here, we must define accumulate type.
// But uint8 currently accumulates into int64, so we would have to make
// an inconsistent choice for the larger types. Difficult.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
@ -82,6 +102,8 @@ namespace c10 {
_(at::Float8_e5m2, Float8_e5m2) \
_(at::Float8_e4m3fn, Float8_e4m3fn)
// This macro controls many of our C++ APIs, including constructors
// for Scalar as well as the data() and item() accessors on Tensor
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
@ -157,6 +179,8 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
#undef SPECIALIZE_CppTypeToScalarType
// NB: despite their generic-sounding names, the macros that don't take _AND
// are mostly only used by tensorexpr
#define AT_FORALL_INT_TYPES(_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
@ -173,6 +197,11 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
_(float, Float) \
_(double, Double)
// These macros are often controlling how many template instantiations we
// create for kernels. It is typically inappropriate to add new dtypes here,
// instead, new types should be added to use sites on a case-by-case basis.
// We generally are not accepting new dtypes due to binary size concerns.
#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
@ -384,7 +413,9 @@ static inline size_t elementSize(ScalarType t) {
// Returns true when `t` is one of the integer dtypes: Byte, Char, Short, Int,
// Long, UInt16, UInt32 or UInt64. Bool is counted as integral only when
// `includeBool` is true.
static inline bool isIntegralType(ScalarType t, bool includeBool) {
  bool isIntegral =
      (t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int ||
       t == ScalarType::Long || t == ScalarType::Short ||
       t == ScalarType::UInt16 || t == ScalarType::UInt32 ||
       t == ScalarType::UInt64);
  return isIntegral || (includeBool && t == ScalarType::Bool);
}
@ -428,6 +459,11 @@ static inline bool isBitsType(ScalarType t) {
t == ScalarType::Bits16;
}
// Returns true exactly when `t` is UInt16, UInt32 or UInt64. Byte (uint8) is
// not part of this set.
static inline bool isBarebonesUnsignedType(ScalarType t) {
  switch (t) {
    case ScalarType::UInt16:
    case ScalarType::UInt32:
    case ScalarType::UInt64:
      return true;
    default:
      return false;
  }
}
static inline ScalarType toQIntType(ScalarType t) {
switch (t) {
case ScalarType::Byte:

View File

@ -670,7 +670,6 @@ inline std::ostream& operator<<(
}
CAFFE_DECLARE_KNOWN_TYPE(std::string, std_string)
CAFFE_DECLARE_KNOWN_TYPE(uint16_t, uint16_t)
CAFFE_DECLARE_KNOWN_TYPE(char, char)
CAFFE_DECLARE_KNOWN_TYPE(std::unique_ptr<std::mutex>, std_unique_ptr_std_mutex)
CAFFE_DECLARE_KNOWN_TYPE(

View File

@ -16,6 +16,12 @@ std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarType) {
// no "byte" because byte is signed in numpy and we overload
// byte to mean bool often
return std::make_pair("uint8", "");
case at::ScalarType::UInt16:
return std::make_pair("uint16", "");
case at::ScalarType::UInt32:
return std::make_pair("uint32", "");
case at::ScalarType::UInt64:
return std::make_pair("uint64", "");
case at::ScalarType::Char:
// no "char" because it is not consistently signed or unsigned; we want
// to move to int8