mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Remove workaround to old CUDA bug (#164354)
As in the title. A check for https://github.com/pytorch/pytorch/issues/164348 to see if the workaround can be removed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164354 Approved by: https://github.com/janeyx99, https://github.com/ngimel, https://github.com/malfet, https://github.com/jeffdaily ghstack dependencies: #164350
This commit is contained in:
committed by
PyTorch MergeBot
parent
48064acf37
commit
26f3803433
@ -120,7 +120,7 @@ static void pow_tensor_scalar_kernel(
|
||||
} else if (dtype == ScalarType::Half) {
|
||||
[&]() {
|
||||
using scalar_t =
|
||||
decltype(c10::impl::ScalarTypeToCPPType<ScalarType::Half>::t);
|
||||
c10::impl::ScalarTypeToCPPTypeT<ScalarType::Half>;
|
||||
const auto exp = exp_scalar.to<scalar_t>();
|
||||
using Vec = Vectorized<scalar_t>;
|
||||
cpu_kernel_vec(iter,
|
||||
|
@ -856,9 +856,13 @@ struct type_specialized_kernel_launcher {
|
||||
out_calc_t output_offset_calculator,
|
||||
loader_t loader,
|
||||
storer_t storer) {
|
||||
if (ret_t == rt_binary_specializations[arg_index][0] &&
|
||||
arg0_t == rt_binary_specializations[arg_index][1] &&
|
||||
arg1_t == rt_binary_specializations[arg_index][2])
|
||||
constexpr ScalarType sret_t = rt_binary_specializations[arg_index][0];
|
||||
constexpr ScalarType sarg0_t = rt_binary_specializations[arg_index][1];
|
||||
constexpr ScalarType sarg1_t = rt_binary_specializations[arg_index][2];
|
||||
if (ret_t == sret_t && arg0_t == sarg0_t && arg1_t == sarg1_t) {
|
||||
using cret_t = c10::impl::ScalarTypeToCPPTypeT<sret_t>;
|
||||
using carg0_t = c10::impl::ScalarTypeToCPPTypeT<sarg0_t>;
|
||||
using carg1_t = c10::impl::ScalarTypeToCPPTypeT<sarg1_t>;
|
||||
launch_vectorized_templated_kernel<
|
||||
func_t,
|
||||
array_t,
|
||||
@ -866,12 +870,9 @@ struct type_specialized_kernel_launcher {
|
||||
out_calc_t,
|
||||
loader_t,
|
||||
storer_t,
|
||||
decltype(c10::impl::ScalarTypeToCPPType<
|
||||
rt_binary_specializations[arg_index][0]>::t),
|
||||
decltype(c10::impl::ScalarTypeToCPPType<
|
||||
rt_binary_specializations[arg_index][1]>::t),
|
||||
decltype(c10::impl::ScalarTypeToCPPType<
|
||||
rt_binary_specializations[arg_index][2]>::t)>(
|
||||
cret_t,
|
||||
carg0_t,
|
||||
carg1_t>(
|
||||
numel,
|
||||
f,
|
||||
data,
|
||||
@ -879,6 +880,7 @@ struct type_specialized_kernel_launcher {
|
||||
output_offset_calculator,
|
||||
loader,
|
||||
storer);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
Reference in New Issue
Block a user