mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add unsigned integer dtypes to PyTorch (#116594)
The dtypes are very useless right now (not even fill works), but it makes torch.uint16, uint32 and uint64 available as a dtype. Towards https://github.com/pytorch/pytorch/issues/58734 Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/116594 Approved by: https://github.com/albanD ghstack dependencies: #116698, #116693
This commit is contained in:
committed by
PyTorch MergeBot
parent
8e273e23b5
commit
8ddac14a15
@ -10,6 +10,9 @@ DLDataType getDLDataType(const Tensor& t) {
|
||||
dtype.bits = t.element_size() * 8;
|
||||
switch (t.scalar_type()) {
|
||||
case ScalarType::Byte:
|
||||
case ScalarType::UInt16:
|
||||
case ScalarType::UInt32:
|
||||
case ScalarType::UInt64:
|
||||
dtype.code = DLDataTypeCode::kDLUInt;
|
||||
break;
|
||||
case ScalarType::Char:
|
||||
|
@ -74,6 +74,10 @@ ScalarType promoteTypes(ScalarType a, ScalarType b) {
|
||||
toString(b));
|
||||
}
|
||||
|
||||
if (isBarebonesUnsignedType(a) || isBarebonesUnsignedType(b)) {
|
||||
return ScalarType::Undefined;
|
||||
}
|
||||
|
||||
auto ix_a = dtype2index[static_cast<int64_t>(a)];
|
||||
TORCH_INTERNAL_ASSERT(ix_a != -1);
|
||||
auto ix_b = dtype2index[static_cast<int64_t>(b)];
|
||||
|
@ -26,11 +26,24 @@
|
||||
namespace c10 {
|
||||
|
||||
// For the macros below:
|
||||
// NB: If you want to macro some code for all non-QInt scalar types (i.e. types
|
||||
// with complete information, you probably want one of the
|
||||
// AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND
|
||||
// macros below, which are designed to behave similarly to the Dispatch macros
|
||||
// with the same name.
|
||||
//
|
||||
// For users: If you want to macro some code for all non-QInt scalar types
|
||||
// (i.e. types with complete information, you probably want one of the
|
||||
// AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND macros below, which are
|
||||
// designed to behave similarly to the Dispatch macros with the same name.
|
||||
//
|
||||
// For adding a new dtype: In the beginning, we had an idea that there was a
|
||||
// list of all scalar types, and you could use AT_FORALL_SCALAR_TYPES to
|
||||
// iterate over them. But over the years we added weird types which couldn't
|
||||
// be handled uniformly everywhere and so in the end we ended up with some
|
||||
// mish-mosh of some helper macros, but mostly use sites making a call about
|
||||
// what dtypes they can or can't support. So if you want to add a new dtype,
|
||||
// the preferred resolution is to find a dtype similar to what you want,
|
||||
// grep for it and edit all the sites you find this way. If you need to add
|
||||
// a completely new kind of dtype, you're going to have to laboriously audit
|
||||
// all of the sites everywhere to figure out how it should work. Consulting
|
||||
// some old PRs where we added new dtypes (check history of this file) can
|
||||
// help give you an idea where to start.
|
||||
|
||||
// NB: Order matters for this macro; it is relied upon in
|
||||
// _promoteTypesLookup and the serialization format.
|
||||
@ -61,11 +74,18 @@ namespace c10 {
|
||||
_(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \
|
||||
_(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \
|
||||
_(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \
|
||||
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */
|
||||
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \
|
||||
_(uint16_t, UInt16) /* 27 */ \
|
||||
_(uint32_t, UInt32) /* 28 */ \
|
||||
_(uint64_t, UInt64) /* 29 */
|
||||
|
||||
// If you want to support ComplexHalf for real, add ComplexHalf
|
||||
// into this macro (and change the name). But beware: convert()
|
||||
// doesn't work for all the conversions you need...
|
||||
//
|
||||
// TODO: To add unsigned int types here, we must define accumulate type.
|
||||
// But uint8 currently accumulates into int64, so we would have to make
|
||||
// an inconsistent choice for the larger types. Difficult.
|
||||
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(_) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
@ -82,6 +102,8 @@ namespace c10 {
|
||||
_(at::Float8_e5m2, Float8_e5m2) \
|
||||
_(at::Float8_e4m3fn, Float8_e4m3fn)
|
||||
|
||||
// This macro controls many of our C++ APIs, including constructors
|
||||
// for Scalar as well as the data() and item() accessors on Tensor
|
||||
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
@ -157,6 +179,8 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
|
||||
|
||||
#undef SPECIALIZE_CppTypeToScalarType
|
||||
|
||||
// NB: despite its generic sounding name, the macros that don't take _AND
|
||||
// are mostly only used by tensorexpr
|
||||
#define AT_FORALL_INT_TYPES(_) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
@ -173,6 +197,11 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
|
||||
_(float, Float) \
|
||||
_(double, Double)
|
||||
|
||||
// These macros are often controlling how many template instantiations we
|
||||
// create for kernels. It is typically inappropriate to add new dtypes here,
|
||||
// instead, new types should be added to use sites on a case-by-case basis.
|
||||
// We generally are not accepting new dtypes due to binary size concerns.
|
||||
|
||||
#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
@ -384,7 +413,9 @@ static inline size_t elementSize(ScalarType t) {
|
||||
static inline bool isIntegralType(ScalarType t, bool includeBool) {
|
||||
bool isIntegral =
|
||||
(t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int ||
|
||||
t == ScalarType::Long || t == ScalarType::Short);
|
||||
t == ScalarType::Long || t == ScalarType::Short ||
|
||||
t == ScalarType::UInt16 || t == ScalarType::UInt32 ||
|
||||
t == ScalarType::UInt64);
|
||||
|
||||
return isIntegral || (includeBool && t == ScalarType::Bool);
|
||||
}
|
||||
@ -428,6 +459,11 @@ static inline bool isBitsType(ScalarType t) {
|
||||
t == ScalarType::Bits16;
|
||||
}
|
||||
|
||||
static inline bool isBarebonesUnsignedType(ScalarType t) {
|
||||
return t == ScalarType::UInt16 || t == ScalarType::UInt32 ||
|
||||
t == ScalarType::UInt64;
|
||||
}
|
||||
|
||||
static inline ScalarType toQIntType(ScalarType t) {
|
||||
switch (t) {
|
||||
case ScalarType::Byte:
|
||||
|
@ -670,7 +670,6 @@ inline std::ostream& operator<<(
|
||||
}
|
||||
|
||||
CAFFE_DECLARE_KNOWN_TYPE(std::string, std_string)
|
||||
CAFFE_DECLARE_KNOWN_TYPE(uint16_t, uint16_t)
|
||||
CAFFE_DECLARE_KNOWN_TYPE(char, char)
|
||||
CAFFE_DECLARE_KNOWN_TYPE(std::unique_ptr<std::mutex>, std_unique_ptr_std_mutex)
|
||||
CAFFE_DECLARE_KNOWN_TYPE(
|
||||
|
@ -16,6 +16,12 @@ std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarType) {
|
||||
// no "byte" because byte is signed in numpy and we overload
|
||||
// byte to mean bool often
|
||||
return std::make_pair("uint8", "");
|
||||
case at::ScalarType::UInt16:
|
||||
return std::make_pair("uint16", "");
|
||||
case at::ScalarType::UInt32:
|
||||
return std::make_pair("uint32", "");
|
||||
case at::ScalarType::UInt64:
|
||||
return std::make_pair("uint64", "");
|
||||
case at::ScalarType::Char:
|
||||
// no "char" because it is not consistently signed or unsigned; we want
|
||||
// to move to int8
|
||||
|
Reference in New Issue
Block a user