[SDPA][EZ] Abate narrowing conversion warning spam in flash_api.cpp (#153643)

For messages like:
```
/workspace/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp:1396:38: warning: narrowing conversion of ‘(char)(& q)->at::Tensor::<anonymous>.at::TensorBase::get_device()’ from ‘char’ to ‘c10::DeviceIndex’ {aka ‘signed
```
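
The diagnostic comes from brace-initializing `at::cuda::CUDAGuard`, whose constructor takes a `c10::DeviceIndex` (a `signed char`, per the `{aka ...}` note), with a value cast to plain `char`: list-initialization forbids implicit narrowing, and on targets where plain `char` is unsigned the conversion from `char` to `signed char` narrows, so GCC warns. Casting directly to the destination type avoids the check. A minimal, self-contained sketch of the pattern (here `Guard`, `DeviceIndex`, and `get_device` are stand-ins for the real types, and the commented-out line assumes a char-unsigned target, e.g. one built with `-funsigned-char`):

```cpp
#include <cstdint>

// Stand-in for c10::DeviceIndex, which the diagnostic reports as 'signed char'.
using DeviceIndex = int8_t;

// Stand-in for at::cuda::CUDAGuard: brace-initializes a DeviceIndex member.
struct Guard {
  explicit Guard(DeviceIndex index) : index_(index) {}
  DeviceIndex index_;
};

// Stand-in for q.get_device(); the exact return type does not matter once
// the call site casts the result to char.
char get_device() { return 0; }

int main() {
  // List-initialization rejects implicit narrowing. Where plain `char` is
  // unsigned (e.g. with -funsigned-char), `signed char` cannot represent
  // every `char` value, so GCC emits -Wnarrowing for the old form:
  //   Guard noisy{(char)get_device()};
  //
  // Casting explicitly to the destination type keeps the original intent
  // and silences the diagnostic:
  Guard quiet{static_cast<signed char>(get_device())};
  return quiet.index_;
}
```

The hunks below apply that one-line substitution at each affected call site in flash_api.cpp.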

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153643
Approved by: https://github.com/Skylion007

Author:    eqy
Date:      2025-05-17 02:07:32 +00:00
Committer: PyTorch MergeBot
Parent:    aac30ef503
Commit:    e802b29ed4

aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp

@@ -479,7 +479,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
 // Otherwise the kernel will be launched from cuda:0 device
 // Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+at::cuda::CUDAGuard device_guard{static_cast<signed char>(q.get_device())};
 auto opts = q.options();
@@ -705,7 +705,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
 // Otherwise the kernel will be launched from cuda:0 device
 // Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+at::cuda::CUDAGuard device_guard{static_cast<signed char>(q.get_device())};
 auto opts = q.options();
@@ -940,7 +940,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
 // Otherwise the kernel will be launched from cuda:0 device
 // Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+at::cuda::CUDAGuard device_guard{static_cast<signed char>(q.get_device())};
 auto opts = q.options();
 auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
@@ -1163,7 +1163,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
 // Otherwise the kernel will be launched from cuda:0 device
 // Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+at::cuda::CUDAGuard device_guard{static_cast<signed char>(q.get_device())};
 auto opts = q.options();
 auto softmax_d = at::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat));
@@ -1393,7 +1393,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
 // Otherwise the kernel will be launched from cuda:0 device
 // Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+at::cuda::CUDAGuard device_guard{static_cast<signed char>(q.get_device())};
 auto opts = q.options();