Remove workaround to old CUDA bug (#164354)

As in the title.

A check for https://github.com/pytorch/pytorch/issues/164348 to see if the workaround can be removed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164354
Approved by: https://github.com/janeyx99, https://github.com/ngimel, https://github.com/malfet, https://github.com/jeffdaily
ghstack dependencies: #164350
This commit is contained in:
Pearu Peterson
2025-10-15 14:54:14 +03:00
committed by PyTorch MergeBot
parent 48064acf37
commit 26f3803433
3 changed files with 79 additions and 103 deletions

View File

@ -120,7 +120,7 @@ static void pow_tensor_scalar_kernel(
} else if (dtype == ScalarType::Half) {
[&]() {
using scalar_t =
decltype(c10::impl::ScalarTypeToCPPType<ScalarType::Half>::t);
c10::impl::ScalarTypeToCPPTypeT<ScalarType::Half>;
const auto exp = exp_scalar.to<scalar_t>();
using Vec = Vectorized<scalar_t>;
cpu_kernel_vec(iter,

View File

@ -856,9 +856,13 @@ struct type_specialized_kernel_launcher {
out_calc_t output_offset_calculator,
loader_t loader,
storer_t storer) {
if (ret_t == rt_binary_specializations[arg_index][0] &&
arg0_t == rt_binary_specializations[arg_index][1] &&
arg1_t == rt_binary_specializations[arg_index][2])
constexpr ScalarType sret_t = rt_binary_specializations[arg_index][0];
constexpr ScalarType sarg0_t = rt_binary_specializations[arg_index][1];
constexpr ScalarType sarg1_t = rt_binary_specializations[arg_index][2];
if (ret_t == sret_t && arg0_t == sarg0_t && arg1_t == sarg1_t) {
using cret_t = c10::impl::ScalarTypeToCPPTypeT<sret_t>;
using carg0_t = c10::impl::ScalarTypeToCPPTypeT<sarg0_t>;
using carg1_t = c10::impl::ScalarTypeToCPPTypeT<sarg1_t>;
launch_vectorized_templated_kernel<
func_t,
array_t,
@ -866,12 +870,9 @@ struct type_specialized_kernel_launcher {
out_calc_t,
loader_t,
storer_t,
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][0]>::t),
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][1]>::t),
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][2]>::t)>(
cret_t,
carg0_t,
carg1_t>(
numel,
f,
data,
@ -880,6 +881,7 @@ struct type_specialized_kernel_launcher {
loader,
storer);
}
}
};
} // namespace

View File

@ -63,15 +63,15 @@ struct dummy_int1_7_t {};
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(at::Half, Half) \
_(c10::Half, Half) \
_(float, Float) \
_(double, Double) \
_(c10::complex<float>, ComplexFloat) \
_(c10::complex<double>, ComplexDouble) \
_(bool, Bool) \
_(at::BFloat16, BFloat16) \
_(at::Float8_e5m2, Float8_e5m2) \
_(at::Float8_e4m3fn, Float8_e4m3fn)
_(c10::BFloat16, BFloat16) \
_(c10::Float8_e5m2, Float8_e5m2) \
_(c10::Float8_e4m3fn, Float8_e4m3fn)
// This macro controls many of our C++ APIs, including constructors
// for Scalar as well as the data() and item() accessors on Tensor
@ -81,19 +81,19 @@ struct dummy_int1_7_t {};
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(at::Half, Half) \
_(c10::Half, Half) \
_(float, Float) \
_(double, Double) \
_(c10::complex<c10::Half>, ComplexHalf) \
_(c10::complex<float>, ComplexFloat) \
_(c10::complex<double>, ComplexDouble) \
_(bool, Bool) \
_(at::BFloat16, BFloat16) \
_(at::Float8_e5m2, Float8_e5m2) \
_(at::Float8_e4m3fn, Float8_e4m3fn) \
_(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
_(c10::BFloat16, BFloat16) \
_(c10::Float8_e5m2, Float8_e5m2) \
_(c10::Float8_e4m3fn, Float8_e4m3fn) \
_(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) \
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) \
_(c10::Float8_e8m0fnu, Float8_e8m0fnu)
// NB: Order matters for this macro; it is relied upon in
// _promoteTypesLookup and the serialization format.
@ -103,7 +103,7 @@ struct dummy_int1_7_t {};
_(int16_t, Short) /* 2 */ \
_(int, Int) /* 3 */ \
_(int64_t, Long) /* 4 */ \
_(at::Half, Half) /* 5 */ \
_(c10::Half, Half) /* 5 */ \
_(float, Float) /* 6 */ \
_(double, Double) /* 7 */ \
_(c10::complex<c10::Half>, ComplexHalf) /* 8 */ \
@ -113,7 +113,7 @@ struct dummy_int1_7_t {};
_(c10::qint8, QInt8) /* 12 */ \
_(c10::quint8, QUInt8) /* 13 */ \
_(c10::qint32, QInt32) /* 14 */ \
_(at::BFloat16, BFloat16) /* 15 */ \
_(c10::BFloat16, BFloat16) /* 15 */ \
_(c10::quint4x2, QUInt4x2) /* 16 */ \
_(c10::quint2x4, QUInt2x4) /* 17 */ \
_(c10::bits1x8, Bits1x8) /* 18 */ \
@ -176,9 +176,7 @@ struct dummy_int1_7_t {};
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE>::t), \
SCALARTYPE)
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE>, SCALARTYPE)
#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
_(uint8_t, Byte) \
@ -188,12 +186,9 @@ struct dummy_int1_7_t {};
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE1>::t), \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE1>, \
SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE2>::t), \
SCALARTYPE2)
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE2>, SCALARTYPE2)
#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \
_(uint8_t, Byte) \
@ -203,15 +198,11 @@ struct dummy_int1_7_t {};
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE1>::t), \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE1>, \
SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE2>::t), \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE2>, \
SCALARTYPE2) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE3>::t), \
SCALARTYPE3)
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE3>, SCALARTYPE3)
#define AT_FORALL_SCALAR_TYPES_AND7( \
SCALARTYPE1, \
@ -229,27 +220,19 @@ struct dummy_int1_7_t {};
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE1>::t), \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE1>, \
SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE2>::t), \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE2>, \
SCALARTYPE2) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE3>::t), \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE3>, \
SCALARTYPE3) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE4>::t), \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE4>, \
SCALARTYPE4) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE5>::t), \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE5>, \
SCALARTYPE5) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE6>::t), \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE6>, \
SCALARTYPE6) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE7>::t), \
SCALARTYPE7)
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE7>, SCALARTYPE7)
#define AT_FORALL_QINT_TYPES(_) \
_(c10::qint8, QInt8) \
@ -259,11 +242,11 @@ struct dummy_int1_7_t {};
_(c10::quint2x4, QUInt2x4)
#define AT_FORALL_FLOAT8_TYPES(_) \
_(at::Float8_e5m2, Float8_e5m2) \
_(at::Float8_e4m3fn, Float8_e4m3fn) \
_(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
_(c10::Float8_e5m2, Float8_e5m2) \
_(c10::Float8_e4m3fn, Float8_e4m3fn) \
_(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) \
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) \
_(c10::Float8_e8m0fnu, Float8_e8m0fnu)
#define AT_FORALL_COMPLEX_TYPES(_) \
_(c10::complex<float>, ComplexFloat) \
@ -291,15 +274,6 @@ struct ScalarTypeToCPPType;
template <> \
struct ScalarTypeToCPPType<c10::ScalarType::scalar_type> { \
using type = cpp_type; \
\
/* This is a workaround for the CUDA bug which prevents */ \
/* ::detail::ScalarTypeToCType<T>::type being used directly due to */ \
/* ambiguous reference which can't to be resolved. For some reason it */ \
/* can't pick between at::detail and at::cuda::detail. */ \
/* For repro example, please see: */ \
/* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */ \
/* TODO: remove once the bug is fixed. */ \
static type t; \
};
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType)