diff --git a/aten/src/ATen/native/cpu/PowKernel.cpp b/aten/src/ATen/native/cpu/PowKernel.cpp index 18e14ed5d30d..ed23503099ed 100644 --- a/aten/src/ATen/native/cpu/PowKernel.cpp +++ b/aten/src/ATen/native/cpu/PowKernel.cpp @@ -120,7 +120,7 @@ static void pow_tensor_scalar_kernel( } else if (dtype == ScalarType::Half) { [&]() { using scalar_t = - decltype(c10::impl::ScalarTypeToCPPType::t); + c10::impl::ScalarTypeToCPPTypeT; const auto exp = exp_scalar.to(); using Vec = Vectorized; cpu_kernel_vec(iter, diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh index ee28c5c1693f..c42d03b9cbf7 100644 --- a/aten/src/ATen/native/cuda/CUDALoops.cuh +++ b/aten/src/ATen/native/cuda/CUDALoops.cuh @@ -856,9 +856,13 @@ struct type_specialized_kernel_launcher { out_calc_t output_offset_calculator, loader_t loader, storer_t storer) { - if (ret_t == rt_binary_specializations[arg_index][0] && - arg0_t == rt_binary_specializations[arg_index][1] && - arg1_t == rt_binary_specializations[arg_index][2]) + constexpr ScalarType sret_t = rt_binary_specializations[arg_index][0]; + constexpr ScalarType sarg0_t = rt_binary_specializations[arg_index][1]; + constexpr ScalarType sarg1_t = rt_binary_specializations[arg_index][2]; + if (ret_t == sret_t && arg0_t == sarg0_t && arg1_t == sarg1_t) { + using cret_t = c10::impl::ScalarTypeToCPPTypeT; + using carg0_t = c10::impl::ScalarTypeToCPPTypeT; + using carg1_t = c10::impl::ScalarTypeToCPPTypeT; launch_vectorized_templated_kernel< func_t, array_t, @@ -866,12 +870,9 @@ struct type_specialized_kernel_launcher { out_calc_t, loader_t, storer_t, - decltype(c10::impl::ScalarTypeToCPPType< - rt_binary_specializations[arg_index][0]>::t), - decltype(c10::impl::ScalarTypeToCPPType< - rt_binary_specializations[arg_index][1]>::t), - decltype(c10::impl::ScalarTypeToCPPType< - rt_binary_specializations[arg_index][2]>::t)>( + cret_t, + carg0_t, + carg1_t>( numel, f, data, @@ -879,6 +880,7 @@ struct type_specialized_kernel_launcher { output_offset_calculator, loader, storer); + } } }; diff --git a/torch/headeronly/core/ScalarType.h b/torch/headeronly/core/ScalarType.h index 6caacd8c119e..613c10853d52 100644 --- a/torch/headeronly/core/ScalarType.h +++ b/torch/headeronly/core/ScalarType.h @@ -63,15 +63,15 @@ struct dummy_int1_7_t {}; _(int16_t, Short) \ _(int, Int) \ _(int64_t, Long) \ - _(at::Half, Half) \ + _(c10::Half, Half) \ _(float, Float) \ _(double, Double) \ _(c10::complex, ComplexFloat) \ _(c10::complex, ComplexDouble) \ _(bool, Bool) \ - _(at::BFloat16, BFloat16) \ - _(at::Float8_e5m2, Float8_e5m2) \ - _(at::Float8_e4m3fn, Float8_e4m3fn) + _(c10::BFloat16, BFloat16) \ + _(c10::Float8_e5m2, Float8_e5m2) \ + _(c10::Float8_e4m3fn, Float8_e4m3fn) // This macro controls many of our C++ APIs, including constructors // for Scalar as well as the data() and item() accessors on Tensor @@ -81,19 +81,19 @@ struct dummy_int1_7_t {}; _(int16_t, Short) \ _(int, Int) \ _(int64_t, Long) \ - _(at::Half, Half) \ + _(c10::Half, Half) \ _(float, Float) \ _(double, Double) \ _(c10::complex, ComplexHalf) \ _(c10::complex, ComplexFloat) \ _(c10::complex, ComplexDouble) \ _(bool, Bool) \ - _(at::BFloat16, BFloat16) \ - _(at::Float8_e5m2, Float8_e5m2) \ - _(at::Float8_e4m3fn, Float8_e4m3fn) \ - _(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \ - _(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \ - _(at::Float8_e8m0fnu, Float8_e8m0fnu) + _(c10::BFloat16, BFloat16) \ + _(c10::Float8_e5m2, Float8_e5m2) \ + _(c10::Float8_e4m3fn, Float8_e4m3fn) \ + _(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) \ + _(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) \ + _(c10::Float8_e8m0fnu, Float8_e8m0fnu) // NB: Order matters for this macro; it is relied upon in // _promoteTypesLookup and the serialization format. @@ -103,7 +103,7 @@ struct dummy_int1_7_t {}; _(int16_t, Short) /* 2 */ \ _(int, Int) /* 3 */ \ _(int64_t, Long) /* 4 */ \ - _(at::Half, Half) /* 5 */ \ + _(c10::Half, Half) /* 5 */ \ _(float, Float) /* 6 */ \ _(double, Double) /* 7 */ \ _(c10::complex, ComplexHalf) /* 8 */ \ @@ -113,7 +113,7 @@ struct dummy_int1_7_t {}; _(c10::qint8, QInt8) /* 12 */ \ _(c10::quint8, QUInt8) /* 13 */ \ _(c10::qint32, QInt32) /* 14 */ \ - _(at::BFloat16, BFloat16) /* 15 */ \ + _(c10::BFloat16, BFloat16) /* 15 */ \ _(c10::quint4x2, QUInt4x2) /* 16 */ \ _(c10::quint2x4, QUInt2x4) /* 17 */ \ _(c10::bits1x8, Bits1x8) /* 18 */ \ @@ -176,24 +176,19 @@ struct dummy_int1_7_t {}; _(int64_t, Long) \ _(float, Float) \ _(double, Double) \ - _(decltype(::c10::impl::ScalarTypeToCPPType< \ - ::c10::ScalarType::SCALARTYPE>::t), \ - SCALARTYPE) + _(c10::impl::ScalarTypeToCPPTypeT, SCALARTYPE) -#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \ - _(uint8_t, Byte) \ - _(int8_t, Char) \ - _(int16_t, Short) \ - _(int, Int) \ - _(int64_t, Long) \ - _(float, Float) \ - _(double, Double) \ - _(decltype(::c10::impl::ScalarTypeToCPPType< \ - ::c10::ScalarType::SCALARTYPE1>::t), \ - SCALARTYPE1) \ - _(decltype(::c10::impl::ScalarTypeToCPPType< \ - ::c10::ScalarType::SCALARTYPE2>::t), \ - SCALARTYPE2) +#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(c10::impl::ScalarTypeToCPPTypeT, \ + SCALARTYPE1) \ + _(c10::impl::ScalarTypeToCPPTypeT, SCALARTYPE2) #define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \ _(uint8_t, Byte) \ @@ -203,53 +198,41 @@ struct dummy_int1_7_t {}; _(int64_t, Long) \ _(float, Float) \ _(double, Double) \ - _(decltype(::c10::impl::ScalarTypeToCPPType< \ - ::c10::ScalarType::SCALARTYPE1>::t), \ + _(c10::impl::ScalarTypeToCPPTypeT, \ SCALARTYPE1) \ - _(decltype(::c10::impl::ScalarTypeToCPPType< \ - ::c10::ScalarType::SCALARTYPE2>::t), \ + _(c10::impl::ScalarTypeToCPPTypeT, \ SCALARTYPE2) \ - _(decltype(::c10::impl::ScalarTypeToCPPType< \ - ::c10::ScalarType::SCALARTYPE3>::t), \ - SCALARTYPE3) + _(c10::impl::ScalarTypeToCPPTypeT, SCALARTYPE3) -#define AT_FORALL_SCALAR_TYPES_AND7( \ - SCALARTYPE1, \ - SCALARTYPE2, \ - SCALARTYPE3, \ - SCALARTYPE4, \ - SCALARTYPE5, \ - SCALARTYPE6, \ - SCALARTYPE7, \ - _) \ - _(uint8_t, Byte) \ - _(int8_t, Char) \ - _(int16_t, Short) \ - _(int, Int) \ - _(int64_t, Long) \ - _(float, Float) \ - _(double, Double) \ - _(decltype(::c10::impl::ScalarTypeToCPPType< \ - ::c10::ScalarType::SCALARTYPE1>::t), \ - SCALARTYPE1) \ - _(decltype(::c10::impl::ScalarTypeToCPPType< \ - ::c10::ScalarType::SCALARTYPE2>::t), \ - SCALARTYPE2) \ - _(decltype(::c10::impl::ScalarTypeToCPPType< \ - ::c10::ScalarType::SCALARTYPE3>::t), \ - SCALARTYPE3) \ - _(decltype(::c10::impl::ScalarTypeToCPPType< \ - ::c10::ScalarType::SCALARTYPE4>::t), \ - SCALARTYPE4) \ - _(decltype(::c10::impl::ScalarTypeToCPPType< \ - ::c10::ScalarType::SCALARTYPE5>::t), \ - SCALARTYPE5) \ - _(decltype(::c10::impl::ScalarTypeToCPPType< \ - ::c10::ScalarType::SCALARTYPE6>::t), \ - SCALARTYPE6) \ - _(decltype(::c10::impl::ScalarTypeToCPPType< \ - ::c10::ScalarType::SCALARTYPE7>::t), \ - SCALARTYPE7) +#define AT_FORALL_SCALAR_TYPES_AND7( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + SCALARTYPE7, \ + _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(c10::impl::ScalarTypeToCPPTypeT, \ + SCALARTYPE1) \ + _(c10::impl::ScalarTypeToCPPTypeT, \ + SCALARTYPE2) \ + _(c10::impl::ScalarTypeToCPPTypeT, \ + SCALARTYPE3) \ + _(c10::impl::ScalarTypeToCPPTypeT, \ + SCALARTYPE4) \ + _(c10::impl::ScalarTypeToCPPTypeT, \ + SCALARTYPE5) \ + _(c10::impl::ScalarTypeToCPPTypeT, \ + SCALARTYPE6) \ + _(c10::impl::ScalarTypeToCPPTypeT, SCALARTYPE7) #define AT_FORALL_QINT_TYPES(_) \ _(c10::qint8, QInt8) \ @@ -258,12 +241,12 @@ struct dummy_int1_7_t {}; _(c10::quint4x2, QUInt4x2) \ _(c10::quint2x4, QUInt2x4) -#define AT_FORALL_FLOAT8_TYPES(_) \ - _(at::Float8_e5m2, Float8_e5m2) \ - _(at::Float8_e4m3fn, Float8_e4m3fn) \ - _(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \ - _(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \ - _(at::Float8_e8m0fnu, Float8_e8m0fnu) +#define AT_FORALL_FLOAT8_TYPES(_) \ + _(c10::Float8_e5m2, Float8_e5m2) \ + _(c10::Float8_e4m3fn, Float8_e4m3fn) \ + _(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) \ + _(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) \ + _(c10::Float8_e8m0fnu, Float8_e8m0fnu) #define AT_FORALL_COMPLEX_TYPES(_) \ _(c10::complex, ComplexFloat) \ @@ -287,19 +270,10 @@ namespace impl { template struct ScalarTypeToCPPType; -#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type) \ - template <> \ - struct ScalarTypeToCPPType { \ - using type = cpp_type; \ - \ - /* This is a workaround for the CUDA bug which prevents */ \ - /* ::detail::ScalarTypeToCType::type being used directly due to */ \ - /* ambiguous reference which can't to be resolved. For some reason it */ \ - /* can't pick between at::detail and at::cuda::detail. */ \ - /* For repro example, please see: */ \ - /* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */ \ - /* TODO: remove once the bug is fixed. */ \ - static type t; \ +#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type) \ + template <> \ + struct ScalarTypeToCPPType { \ + using type = cpp_type; \ }; AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType)