Remove workaround to old CUDA bug (#164354)

As in the title.

A check for https://github.com/pytorch/pytorch/issues/164348 to see if the workaround can be removed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164354
Approved by: https://github.com/janeyx99, https://github.com/ngimel, https://github.com/malfet, https://github.com/jeffdaily
ghstack dependencies: #164350
This commit is contained in:
Pearu Peterson
2025-10-15 14:54:14 +03:00
committed by PyTorch MergeBot
parent 48064acf37
commit 26f3803433
3 changed files with 79 additions and 103 deletions

View File

@ -120,7 +120,7 @@ static void pow_tensor_scalar_kernel(
} else if (dtype == ScalarType::Half) {
[&]() {
using scalar_t =
decltype(c10::impl::ScalarTypeToCPPType<ScalarType::Half>::t);
c10::impl::ScalarTypeToCPPTypeT<ScalarType::Half>;
const auto exp = exp_scalar.to<scalar_t>();
using Vec = Vectorized<scalar_t>;
cpu_kernel_vec(iter,

View File

@ -856,9 +856,13 @@ struct type_specialized_kernel_launcher {
out_calc_t output_offset_calculator,
loader_t loader,
storer_t storer) {
if (ret_t == rt_binary_specializations[arg_index][0] &&
arg0_t == rt_binary_specializations[arg_index][1] &&
arg1_t == rt_binary_specializations[arg_index][2])
constexpr ScalarType sret_t = rt_binary_specializations[arg_index][0];
constexpr ScalarType sarg0_t = rt_binary_specializations[arg_index][1];
constexpr ScalarType sarg1_t = rt_binary_specializations[arg_index][2];
if (ret_t == sret_t && arg0_t == sarg0_t && arg1_t == sarg1_t) {
using cret_t = c10::impl::ScalarTypeToCPPTypeT<sret_t>;
using carg0_t = c10::impl::ScalarTypeToCPPTypeT<sarg0_t>;
using carg1_t = c10::impl::ScalarTypeToCPPTypeT<sarg1_t>;
launch_vectorized_templated_kernel<
func_t,
array_t,
@ -866,12 +870,9 @@ struct type_specialized_kernel_launcher {
out_calc_t,
loader_t,
storer_t,
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][0]>::t),
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][1]>::t),
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][2]>::t)>(
cret_t,
carg0_t,
carg1_t>(
numel,
f,
data,
@ -879,6 +880,7 @@ struct type_specialized_kernel_launcher {
output_offset_calculator,
loader,
storer);
}
}
};