Support unsigned int for randint, item, equality, fill, iinfo, tensor (#116805)
These are some basic utilities that are often used for testing.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116805
Approved by: https://github.com/albanD
commit 2e983fcfd3 (parent 4a10e9eed4)
committed by: PyTorch MergeBot
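What the change enables, roughly: uint16/uint32/uint64 ("barebones unsigned") tensors can now be created, filled, compared for equality, sampled with random_/randint, read back with item(), and queried via iinfo. A minimal C++ sketch of that surface, added by the editor for orientation; the exact ATen calls below are illustrative assumptions, not code from this patch:

    #include <ATen/ATen.h>

    int main() {
      // randint / random_ path, now dispatched for kUInt16 (see the
      // DistributionTemplates hunks below)
      at::Tensor t = at::randint(/*low=*/0, /*high=*/100, {4}, at::dtype(at::kUInt16));

      // fill and equality kernels also cover the unsigned dtypes now
      t.fill_(7);
      at::Tensor same = t.eq(t);

      // item() goes through _local_scalar_dense, extended below for kUInt16/32/64
      int64_t first = t[0].item<int64_t>();
      (void)first;
      (void)same;
      return 0;
    }

The diff follows.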
@@ -85,14 +85,18 @@
 // to be interpreted as being multiple arguments
 #define AT_WRAP(...) __VA_ARGS__
 
-#define AT_FLOAT8_TYPES \
-  kFloat8_e5m2, kFloat8_e5m2fnuz, kFloat8_e4m3fn, kFloat8_e4m3fnuz
+#define AT_FLOAT8_TYPES \
+  c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn, \
+      c10::kFloat8_e4m3fnuz
 
-#define AT_INTEGRAL_TYPES kByte, kChar, kInt, kLong, kShort
-#define AT_FLOATING_TYPES kDouble, kFloat
-#define AT_BAREBONES_UNSIGNED_TYPES kUInt16, kUInt32, kUInt64
-#define AT_COMPLEX_TYPES kComplexDouble, kComplexFloat
-#define AT_QINT_TYPES kQInt8, kQUInt8, kQInt32
+#define AT_INTEGRAL_TYPES \
+  c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort
+#define AT_FLOATING_TYPES c10::kDouble, c10::kFloat
+#define AT_BAREBONES_UNSIGNED_TYPES c10::kUInt16, c10::kUInt32, c10::kUInt64
+#define AT_INTEGRAL_TYPES_V2 \
+  AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
+#define AT_COMPLEX_TYPES c10::kComplexDouble, c10::kComplexFloat
+#define AT_QINT_TYPES c10::kQInt8, c10::kQUInt8, c10::kQInt32
 // NB: not *actually* all types
 #define AT_ALL_TYPES AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)
 #define AT_ALL_TYPES_AND_COMPLEX \
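The remaining hunks repeatedly apply the same mechanical rewrite: a fixed-arity AT_DISPATCH_ALL_TYPES_AND_COMPLEX_ANDn(...) call becomes an AT_DISPATCH_V2 call whose body is protected by AT_WRAP and whose type list is spelled with AT_EXPAND plus the extra kUInt16/kUInt32/kUInt64 entries. A minimal sketch of that pattern, added by the editor; the kernel name and body are made up for illustration and are not part of the patch:

    #include <ATen/Dispatch_v2.h>
    #include <ATen/native/TensorIterator.h>
    #include <ATen/native/cpu/Loops.h>

    namespace at::native {

    // Assumes `iter` was configured as a binary op (two inputs, one output).
    void example_kernel(TensorIteratorBase& iter) {
      AT_DISPATCH_V2(iter.dtype(), "example_kernel", AT_WRAP([&] {
        // scalar_t now also ranges over uint16_t/uint32_t/uint64_t when the
        // input tensor has one of the barebones unsigned dtypes.
        cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t { return a + b; });
      }), AT_EXPAND(AT_INTEGRAL_TYPES_V2), kBool);
    }

    } // namespace at::native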
@@ -2,6 +2,7 @@
 
 #include <ATen/core/Tensor.h>
 #include <ATen/Dispatch.h>
+#include <ATen/Dispatch_v2.h>
 #include <ATen/Generator.h>
 #include <ATen/ExpandUtils.h>
 #include <ATen/Tensor.h>
@@ -110,13 +111,21 @@ static void check_from_to_in_range(int64_t from, int64_t to_inc, caffe2::TypeMet
       WARN_OUT_OF_BOUNDS(from, "from", digits, dtype);
       WARN_OUT_OF_BOUNDS(to_inc, "to - 1", digits, dtype);
     });
+  } else if (scalar_type == kUInt64) {
+    // When you do a comparison between int64_t and uint64_t, the usual
+    // arithmetic conversions say that the int64_t value is promoted to
+    // unsigned. But this conversion wraps around: if I had -1 as my int64_t,
+    // then it will promote to 0xFFFFFFFFFFFFFFFF in uint64_t. This is never
+    // the right thing to do.
+    CHECK_OUT_OF_BOUNDS(from, "from", 0, INT64_MAX, dtype);
+    CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", 0, INT64_MAX, dtype);
   } else if (isIntegralType(scalar_type, /*includeBool=*/true)) {
-    AT_DISPATCH_INTEGRAL_TYPES_AND(at::ScalarType::Bool, scalar_type, "check_random_integral_bounds", [&]() {
+    AT_DISPATCH_V2(scalar_type, "check_random_integral_bounds", AT_WRAP([&]() {
       const auto min = static_cast<int64_t>(std::numeric_limits<scalar_t>::lowest());
       const auto max = static_cast<int64_t>(std::numeric_limits<scalar_t>::max());
       CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
       CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, dtype);
-    });
+    }), AT_EXPAND(AT_INTEGRAL_TYPES), kUInt16, kUInt32, kBool);
   } else {
     TORCH_CHECK(false, "check_random_bounds handles only integral, floating-point and boolean types");
   }
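The wrap-around the comment describes is plain C++ usual arithmetic conversions; a tiny standalone illustration (editor-added, not from the patch) of why the uint64 branch clamps the bounds to [0, INT64_MAX] rather than comparing against uint64_t limits directly:

    #include <cstdint>
    #include <iostream>
    #include <limits>

    int main() {
      int64_t from = -1;
      uint64_t umax = std::numeric_limits<uint64_t>::max();
      // `from` is converted to uint64_t for the comparison, wrapping to
      // 0xFFFFFFFFFFFFFFFF, so -1 compares equal to UINT64_MAX here.
      std::cout << std::boolalpha << (from == umax) << "\n";  // prints true
      return 0;
    }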
@@ -152,13 +161,13 @@ at::Tensor& random_from_to_impl(at::Tensor& self, int64_t from, c10::optional<in
       TORCH_CHECK(from < to_inc, "random_ expects 'from' casted to dtype to be less than or equal to 'to_inc' casted to dtype, but got from=", from, " > to_inc=", to_inc);
     });
   } else if (isIntegralType(iter.dtype(), /*includeBool=*/true)) {
-    AT_DISPATCH_INTEGRAL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "random_from_to_range_calc", [&] {
+    AT_DISPATCH_V2(self.scalar_type(), "random_from_to_range_calc", AT_WRAP([&] {
       if constexpr (std::is_same_v<scalar_t, bool>) {
         to_inc = static_cast<int64_t>(true);
       } else {
         to_inc = static_cast<int64_t>(std::numeric_limits<scalar_t>::max());
       }
-    });
+    }), AT_EXPAND(AT_INTEGRAL_TYPES_V2), kBool);
   } else {
     TORCH_CHECK(false, "random_from_to_impl handles only integral, floating-point and boolean types");
   }
@@ -104,6 +104,7 @@ static inline void check_scalar_type_device_layout_equal(const Tensor& out, cons
 
 static inline Tensor integer_upcast(const Tensor& self, c10::optional<ScalarType> dtype) {
   ScalarType scalarType = self.scalar_type();
+  TORCH_CHECK(!isBarebonesUnsignedType(scalarType), "integer upcasting for uint16, uint32 and uint64 is not currently implemented");
   ScalarType upcast_scalarType = dtype.value_or(at::isIntegralType(scalarType, /*includeBool=*/true) ? ScalarType::Long : scalarType);
   return self.toType(upcast_scalarType);
 }
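integer_upcast is what makes integer reductions accumulate in int64; the new TORCH_CHECK refuses that path for the barebones unsigned dtypes rather than silently picking a lossy accumulator. A small editor-added sketch of the behaviour it guards; whether a given reduction routes through this helper is an assumption based on the check above:

    #include <ATen/ATen.h>
    #include <c10/util/Exception.h>

    int main() {
      // Integer inputs upcast to kLong before summing...
      auto b = at::ones({4}, at::dtype(at::kByte));
      TORCH_INTERNAL_ASSERT(b.sum().scalar_type() == at::kLong);

      // ...whereas the same reduction on a kUInt16 tensor is expected to throw,
      // per the check added above, until unsigned upcasting is implemented.
      auto u = at::ones({4}, at::dtype(at::kUInt16));
      try {
        u.sum();
      } catch (const c10::Error&) {
        // "integer upcasting for uint16, uint32 and uint64 is not currently implemented"
      }
      return 0;
    }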
@@ -27,7 +27,7 @@ Scalar item(const Tensor& self) {
   }
 }
 
-#define AT_SD_BASE_TYPES AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), kComplexHalf, kHalf, kBool, kBFloat16
+#define AT_SD_BASE_TYPES AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
 #if !defined(C10_MOBILE)
 #define AT_SD_TYPES AT_EXPAND(AT_SD_BASE_TYPES), AT_EXPAND(AT_FLOAT8_TYPES)
 #else
@@ -4,6 +4,7 @@
 #include <cmath>
 
 #include <ATen/Dispatch.h>
+#include <ATen/Dispatch_v2.h>
 #include <ATen/OpMathType.h>
 #include <ATen/Parallel.h>
 #include <ATen/cpu/vec/functional.h>

@@ -81,7 +82,10 @@ void atan2_kernel(TensorIteratorBase& iter) {
 
 #if !defined(C10_MOBILE)
 #define _AT_DISPATCH_ALL_TYPES_AND_BOOL(TYPE, NAME, ...) \
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND8(                \
+  AT_DISPATCH_V2(           \
+      TYPE,                 \
+      NAME,                 \
+      AT_WRAP(__VA_ARGS__), \
       kComplexHalf,         \
       kHalf,                \
       kBool,                \

@@ -89,23 +93,21 @@ void atan2_kernel(TensorIteratorBase& iter) {
       kFloat8_e5m2,         \
       kFloat8_e5m2fnuz,     \
       kFloat8_e4m3fn,       \
-      kFloat8_e4m3fnuz,     \
-      TYPE,                 \
-      NAME,                 \
-      __VA_ARGS__)
+      kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
 #define _AT_DISPATCH_ALL_TYPES_NO_BOOL(TYPE, NAME, ...) \
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5(               \
+  AT_DISPATCH_V2(           \
+      TYPE,                 \
+      NAME,                 \
+      AT_WRAP(__VA_ARGS__), \
       kComplexHalf,         \
      kHalf,                \
       kBFloat16,            \
       kFloat8_e5m2,         \
       kFloat8_e4m3fn,       \
-      TYPE,                 \
-      NAME,                 \
-      __VA_ARGS__)
+      AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
 #define _AT_DISPATCH_MUL_TYPES(TYPE, NAME, ...) \
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(       \
-      kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, TYPE, NAME, __VA_ARGS__)
+  AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \
+      kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
 #else
 #define _AT_DISPATCH_ALL_TYPES_AND_BOOL(TYPE, NAME, ...) \
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(                \
@@ -2,6 +2,7 @@
 
 #include <ATen/CPUApplyUtils.h>
 #include <ATen/Dispatch.h>
+#include <ATen/Dispatch_v2.h>
 #include <ATen/ExpandBase.h>
 #include <ATen/core/DistributionsHelper.h>
 #include <ATen/native/TensorIterator.h>

@@ -25,13 +26,13 @@ namespace {
 
 template<typename RNG>
 void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG generator) {
-  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "random_from_to_kernel_cpu", [&] {
+  AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cpu", AT_WRAP([&] {
     std::lock_guard<std::mutex> lock(generator->mutex_);
     cpu_serial_kernel(iter, [range, base, generator]() -> scalar_t {
       uniform_int_from_to_distribution<scalar_t> random(range, base);
       return random(generator);
     });
-  });
+  }), kBool, kHalf, kBFloat16, AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
 }
 
 // This is the special kernel to handle single specific case:
@@ -52,7 +52,7 @@ void fill_kernel(TensorIterator& iter, const Scalar& value_scalar) {
         [=]() -> scalar_t { return value; },
         [=]() { return Vectorized<scalar_t>(value); });
       }),
-      AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kBool
+      AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kBool, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
     );
   }
 }
@@ -1,6 +1,6 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/core/Tensor.h>
-#include <ATen/Dispatch.h>
+#include <ATen/Dispatch_v2.h>
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/NativeFunctions.h>

@@ -14,13 +14,13 @@ namespace at::native {
 
 Scalar _local_scalar_dense_cuda(const Tensor& self) {
   Scalar r;
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
-      kComplexHalf, kHalf, kBool, kBFloat16, self.scalar_type(), "_local_scalar_dense_cuda", [&] {
+  AT_DISPATCH_V2(
+      self.scalar_type(), "_local_scalar_dense_cuda", AT_WRAP([&] {
         scalar_t value;
         cudaStream_t stream = at::cuda::getCurrentCUDAStream();
         at::cuda::memcpy_and_sync(&value, self.const_data_ptr<scalar_t>(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream);
         r = Scalar(value);
-      });
+      }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
   return r;
 }
 
@@ -1,5 +1,6 @@
 #define TORCH_ASSERT_NO_OPERATORS
 #include <ATen/Dispatch.h>
+#include <ATen/Dispatch_v2.h>
 #include <ATen/native/BinaryOps.h>
 #include <ATen/native/DispatchStub.h>
 #include <ATen/native/TensorIterator.h>

@@ -29,11 +30,10 @@ struct CompareEqFunctor{
 }
 
 C10_NOINLINE void compare_eq_ne_kernel(TensorIteratorBase &iter, EqOpType op) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(kComplexHalf, kHalf, kBFloat16, kBool, kFloat8_e4m3fn, kFloat8_e5m2,
-      iter.common_dtype(), "compare_eq_ne_cuda", [&]() {
+  AT_DISPATCH_V2(iter.common_dtype(), "compare_eq_ne_cuda", AT_WRAP([&]() {
     opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
         iter, CompareEqFunctor<scalar_t>(op));
-  });
+  }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBFloat16, kBool, kFloat8_e4m3fn, kFloat8_e5m2, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
 }
 
 void eq_kernel_cuda(TensorIteratorBase& iter) {
@@ -2,6 +2,7 @@
 #include <ATen/core/Tensor.h>
 #include <ATen/Context.h>
 #include <ATen/Dispatch.h>
+#include <ATen/Dispatch_v2.h>
 #include <ATen/cuda/CachingHostAllocator.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/CUDAEvent.h>

@@ -98,6 +99,8 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
   }
 }
 
+// TODO: We probably can use the opaque type trick to avoid creating duplicate
+// kernels for equivalent bit lengths
 void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
   ScalarType dtype = iter.dtype(0);
   if (isQIntType(dtype)) {

@@ -115,10 +118,10 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
     });
 #endif /* !defined(USE_ROCM) */
   } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
-        kHalf, kBool, kBFloat16, kComplexHalf,dtype, "copy_", [&] {
+    AT_DISPATCH_V2(
+        dtype, "copy_", AT_WRAP([&] {
           gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
-    });
+    }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kHalf, kBool, kBFloat16, kComplexHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
   }
 }
 
@@ -2,6 +2,7 @@
 
 #include <ATen/AccumulateType.h>
 #include <ATen/Dispatch.h>
+#include <ATen/Dispatch_v2.h>
 #include <ATen/ExpandBase.h>
 #include <ATen/OpMathType.h>
 #include <ATen/native/TensorIterator.h>

@@ -285,7 +286,7 @@ namespace cuda {
 
 template<typename RNG>
 void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) {
-  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "random_from_to_kernel_cuda", [&] {
+  AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] {
     if ((
       std::is_same<scalar_t, int64_t>::value ||
       std::is_same<scalar_t, double>::value ||

@@ -317,7 +318,7 @@ void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t bas
         },
         random_func);
     }
-  });
+  }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
 }
 
 // This is the special kernel to handle single specific case:
@@ -1,5 +1,6 @@
 #define TORCH_ASSERT_NO_OPERATORS
 #include <ATen/Dispatch.h>
+#include <ATen/Dispatch_v2.h>
 #include <ATen/native/cuda/Loops.cuh>
 #include <ATen/native/DispatchStub.h>
 #include <ATen/native/TensorIterator.h>

@@ -19,9 +20,9 @@ struct FillFunctor {
 };
 
 void fill_kernel_cuda(TensorIterator& iter, const Scalar& value) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(kComplexHalf, kBool, kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2, iter.dtype(), "fill_cuda", [&]() {
+  AT_DISPATCH_V2(iter.dtype(), "fill_cuda", AT_WRAP([&]() {
     gpu_kernel(iter, FillFunctor<scalar_t>(value.to<scalar_t>()));
-  });
+  }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kBool, kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
 }
 
 REGISTER_DISPATCH(fill_stub, &fill_kernel_cuda);
@@ -7,6 +7,7 @@
 #include <ATen/core/Tensor.h>
 #include <ATen/ceil_div.h>
 #include <ATen/Dispatch.h>
+#include <ATen/Dispatch_v2.h>
 #include <ATen/ExpandUtils.h>
 #include <ATen/MemoryOverlap.h>
 #include <ATen/TensorOperators.h>

@@ -1481,14 +1482,16 @@ Tensor& index_select_out_cuda(
           index_select_out_cuda_impl<scalar_t>(out, self, dim, index);
         });
   } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
-        at::ScalarType::ComplexHalf,
-        at::ScalarType::Half,
-        at::ScalarType::Bool,
-        at::ScalarType::BFloat16,
+    AT_DISPATCH_V2(
         out.scalar_type(),
         "index_select_cuda",
-        [&] { index_select_out_cuda_impl<scalar_t>(out, self, dim, index); });
+        AT_WRAP([&] { index_select_out_cuda_impl<scalar_t>(out, self, dim, index); }),
+        AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
+        kComplexHalf,
+        kHalf,
+        kBool,
+        kBFloat16
+    );
   }
 
   return out;
@@ -9,6 +9,7 @@
 #include <ATen/native/TypeProperties.h>
 #include <ATen/native/TensorShape.h>
 #include <ATen/Dispatch.h>
+#include <ATen/Dispatch_v2.h>
 #include <c10/core/MemoryFormat.h>
 #include <c10/util/Optional.h>
 

@@ -431,12 +432,10 @@ TORCH_IMPL_FUNC(cat_out_cuda)
           parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE, 1>(result, materialized, dim, nDims, memory_format);
         });
     } else {
-      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
-          kComplexHalf, kHalf, kBool, kBFloat16,
-          result.scalar_type(), "cat_cuda", [&]() {
+      AT_DISPATCH_V2(result.scalar_type(), "cat_cuda", AT_WRAP([&]() {
           using dtype = OpaqueType<sizeof(scalar_t)>;
           parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE, 1>(result, materialized, dim, nDims, memory_format);
-      });
+      }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
     }
   } else if (materialized.size() > 1 &&
       result.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&

@@ -451,12 +450,10 @@ TORCH_IMPL_FUNC(cat_out_cuda)
           parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE/2, CAT_ARRAY_BATCH_SIZE/2>(result, materialized, dim, nDims, memory_format);
         });
     } else {
-      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
-          kComplexHalf, kHalf, kBool, kBFloat16,
-          result.scalar_type(), "cat_cuda", [&]() {
+      AT_DISPATCH_V2(result.scalar_type(), "cat_cuda", AT_WRAP([&]() {
          using dtype = OpaqueType<sizeof(scalar_t)>;
          parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE/2, CAT_ARRAY_BATCH_SIZE/2>(result, materialized, dim, nDims, memory_format);
-      });
+      }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
     }
   } else {
     int64_t offset = 0;
@@ -71,6 +71,9 @@ C10_HOST_DEVICE inline dest_t fetch_and_cast(
     const void* ptr) {
   switch (src_type) {
     AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(FETCH_AND_CAST_CASE)
+    FETCH_AND_CAST_CASE(uint16_t, UInt16)
+    FETCH_AND_CAST_CASE(uint32_t, UInt32)
+    FETCH_AND_CAST_CASE(uint64_t, UInt64)
     default:
       ERROR_UNSUPPORTED_CAST
   }

@@ -90,6 +93,9 @@ C10_HOST_DEVICE inline void cast_and_store(
     src_t value) {
   switch (dest_type) {
     AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(CAST_AND_STORE_CASE)
+    CAST_AND_STORE_CASE(uint16_t, UInt16)
+    CAST_AND_STORE_CASE(uint32_t, UInt32)
+    CAST_AND_STORE_CASE(uint64_t, UInt64)
     default:;
   }
   ERROR_UNSUPPORTED_CAST
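These FETCH_AND_CAST_CASE / CAST_AND_STORE_CASE entries extend the dynamic-casting switch that TensorIterator uses when input and output dtypes differ, so mixed-dtype copies can now target the unsigned types. A hedged editor-added sketch of the kind of conversion that rides on this path; the exact call is an assumption, not part of the diff:

    #include <ATen/ATen.h>

    int main() {
      // float -> uint16 copy: values are fetched as float and stored as
      // uint16_t through the cast_and_store case added above.
      auto f = at::arange(4, at::dtype(at::kFloat));
      auto u = f.to(at::kUInt16);
      (void)u;
      return 0;
    }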
@@ -75,7 +75,21 @@ ScalarType promoteTypes(ScalarType a, ScalarType b) {
   }
 
   if (isBarebonesUnsignedType(a) || isBarebonesUnsignedType(b)) {
-    return ScalarType::Undefined;
+    // There are two problems with promotion here:
+    //
+    // - Our promotion rule for uint8 is inconsistent with Numpy; Numpy
+    //   promotes to uint64, but since we never had uint64 for the longest
+    //   time, we promote to int64. Changing this is BC-breaking
+    //
+    // - We must not promote uint64 to int64 because this will overflow.
+    //
+    // It'll be a bit of work to fix it, so we're punting on it for now.
+    TORCH_CHECK(
+        false,
+        "Promotion for uint16, uint32, uint64 types is not supported, attempted to promote ",
+        toString(a),
+        " and ",
+        toString(b));
   }
 
   auto ix_a = dtype2index[static_cast<int64_t>(a)];
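In other words, promoteTypes now fails loudly instead of quietly returning Undefined when a uint16/uint32/uint64 operand shows up. A small standalone illustration, editor-added; the calls are chosen for the example, not taken from the patch:

    #include <c10/core/ScalarType.h>
    #include <c10/util/Exception.h>
    #include <iostream>

    int main() {
      // Ordinary promotion still works as before.
      std::cout << c10::toString(c10::promoteTypes(c10::kChar, c10::kFloat)) << "\n";  // Float

      // Any pairing involving a barebones unsigned type is rejected.
      try {
        c10::promoteTypes(c10::kUInt16, c10::kLong);
      } catch (const c10::Error& e) {
        std::cout << "refused: " << e.what_without_backtrace() << "\n";
      }
      return 0;
    }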
@@ -149,7 +149,7 @@ class TestTorchDeviceType(TestCase):
     @onlyNativeDeviceTypes
     @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64,
             torch.bool, torch.float32, torch.complex64, torch.float64,
-            torch.complex128)
+            torch.complex128, torch.uint16, torch.uint32, torch.uint64)
     def test_bytes_to_scalar(self, device, dtype):
         def rand_byte():
             if dtype == torch.bool:

@@ -166,7 +166,7 @@ class TestTorchDeviceType(TestCase):
 
     @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64,
             torch.bool, torch.float32, torch.complex64, torch.float64,
-            torch.complex128)
+            torch.complex128, torch.uint16, torch.uint32, torch.uint64)
     def test_storage(self, device, dtype):
         v = make_tensor((3, 5), dtype=dtype, device=device, low=-9, high=9)
         self.assertEqual(v.storage()[0], v[0][0])
@@ -1296,6 +1296,9 @@ def gen_pyi(
         "float8_e5m2fnuz",
         "half",
         "uint8",
+        "uint16",
+        "uint32",
+        "uint64",
         "int8",
         "int16",
         "short",
@@ -8,6 +8,8 @@
 #include <torch/csrc/utils/python_strings.h>
 #include <torch/csrc/utils/tensor_dtypes.h>
 
+#include <ATen/Dispatch_v2.h>
+
 #include <c10/util/Exception.h>
 
 #include <structmember.h>

@@ -150,11 +152,19 @@ static PyObject* THPFInfo_min(THPFInfo* self, void*) {
   END_HANDLE_TH_ERRORS
 }
 
+#define AT_DISPATCH_IINFO_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_V2(                                \
+      TYPE, NAME, AT_WRAP(__VA_ARGS__), AT_EXPAND(AT_INTEGRAL_TYPES_V2))
+
 static PyObject* THPIInfo_max(THPIInfo* self, void*) {
   HANDLE_TH_ERRORS
   if (at::isIntegralType(self->type, /*includeBool=*/false)) {
-    return AT_DISPATCH_INTEGRAL_TYPES(self->type, "max", [] {
-      return THPUtils_packInt64(std::numeric_limits<scalar_t>::max());
+    return AT_DISPATCH_IINFO_TYPES(self->type, "max", [] {
+      if (std::is_unsigned_v<scalar_t>) {
+        return THPUtils_packUInt64(std::numeric_limits<scalar_t>::max());
+      } else {
+        return THPUtils_packInt64(std::numeric_limits<scalar_t>::max());
+      }
     });
   }
   // Quantized Type

@@ -167,8 +177,12 @@ static PyObject* THPIInfo_max(THPIInfo* self, void*) {
 static PyObject* THPIInfo_min(THPIInfo* self, void*) {
   HANDLE_TH_ERRORS
   if (at::isIntegralType(self->type, /*includeBool=*/false)) {
-    return AT_DISPATCH_INTEGRAL_TYPES(self->type, "min", [] {
-      return THPUtils_packInt64(std::numeric_limits<scalar_t>::lowest());
+    return AT_DISPATCH_IINFO_TYPES(self->type, "min", [] {
+      if (std::is_unsigned_v<scalar_t>) {
+        return THPUtils_packUInt64(std::numeric_limits<scalar_t>::lowest());
+      } else {
+        return THPUtils_packInt64(std::numeric_limits<scalar_t>::lowest());
+      }
     });
   }
   // Quantized Type

@@ -181,7 +195,7 @@ static PyObject* THPIInfo_min(THPIInfo* self, void*) {
 static PyObject* THPIInfo_dtype(THPIInfo* self, void*) {
   HANDLE_TH_ERRORS
   auto primary_name = torch::utils::getDtypeNames(self->type).first;
-  return AT_DISPATCH_INTEGRAL_TYPES(self->type, "dtype", [&primary_name] {
+  return AT_DISPATCH_IINFO_TYPES(self->type, "dtype", [&primary_name] {
     return PyUnicode_FromString(primary_name.data());
   });
   END_HANDLE_TH_ERRORS
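The is_unsigned branch above matters because the maxima of the new dtypes no longer fit in the signed 64-bit value that THPUtils_packInt64 hands back to Python; uint64's maximum in particular needs the unsigned packing path. A one-line compile-time check of that fact, editor-added and not from the patch:

    #include <cstdint>
    #include <limits>

    // 18446744073709551615 (uint64 max) is larger than 9223372036854775807
    // (int64 max), so packing it through a signed 64-bit integer would mangle it.
    static_assert(std::numeric_limits<uint64_t>::max() >
                      static_cast<uint64_t>(std::numeric_limits<int64_t>::max()),
                  "uint64 max does not fit in int64");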
@@ -36,7 +36,8 @@ inline void store_scalar(void* data, at::ScalarType scalarType, PyObject* obj) {
       *(uint32_t*)data = unpackIntegral<uint32_t>(obj, "uint32");
       break;
     case at::kUInt64:
-      *(uint64_t*)data = unpackIntegral<uint64_t>(obj, "uint64");
+      // NB: This doesn't allow implicit conversion of float to int
+      *(uint64_t*)data = THPUtils_unpackUInt64(obj);
       break;
     case at::kChar:
       *(int8_t*)data = unpackIntegral<int8_t>(obj, "int8");
@@ -9,7 +9,16 @@ from typing import cast, List, Optional, Tuple, Union
 
 import torch
 
-_INTEGRAL_TYPES = [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
+_INTEGRAL_TYPES = [
+    torch.uint8,
+    torch.int8,
+    torch.int16,
+    torch.int32,
+    torch.int64,
+    torch.uint16,
+    torch.uint32,
+    torch.uint64,
+]
 _FLOATING_TYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
 _FLOATING_8BIT_TYPES = [torch.float8_e4m3fn, torch.float8_e5m2]
 _COMPLEX_TYPES = [torch.complex32, torch.complex64, torch.complex128]
@@ -4506,6 +4506,9 @@ def bytes_to_scalar(byte_list: List[int], dtype: torch.dtype, device: torch.devi
     dtype_to_ctype: Dict[torch.dtype, Any] = {
         torch.int8: ctypes.c_int8,
         torch.uint8: ctypes.c_uint8,
+        torch.uint16: ctypes.c_uint16,
+        torch.uint32: ctypes.c_uint32,
+        torch.uint64: ctypes.c_uint64,
         torch.int16: ctypes.c_int16,
         torch.int32: ctypes.c_int32,
         torch.int64: ctypes.c_int64,