Accept non-standard bools in more CUDA kernels
This fixes all remaining CUDA kernels, except those using `cub` or `thrust`, to accept boolean tensors with values other than 1 or 0. I do this by using `c10::load` in more places, and also by adding a `load_vector` helper into `MemoryAccess.cuh` that does the same thing for vectorized loads.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78957
Approved by: https://github.com/mruberry
commit cd9e158007
parent 4945c72151
committed by PyTorch MergeBot
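The `load_vector` half of the change lives in `MemoryAccess.cuh` and is not shown in the hunks below. A minimal sketch of that kind of helper, assuming an `aligned_vector`-style wrapper struct; the names and signatures here are illustrative, not copied from the PR:

    // Sketch only: a small aligned vector wrapper and a vectorized load.
    #include <cstdint>

    template <typename scalar_t, int vec_size>
    struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
      scalar_t val[vec_size];
    };

    // Generic vectorized load: reinterpret the base pointer as a vector
    // type and read one whole vector per call.
    template <int vec_size, typename scalar_t>
    __device__ aligned_vector<scalar_t, vec_size> load_vector(
        const scalar_t* base_ptr, uint32_t offset) {
      using vec_t = aligned_vector<scalar_t, vec_size>;
      auto* from = reinterpret_cast<const vec_t*>(base_ptr);
      return from[offset];
    }

    // bool overload: load the bytes first, then convert each byte to bool,
    // so a stored value other than 0/1 still yields a valid bool
    // (same idea as NOTE [Loading boolean values] in the hunk below).
    template <int vec_size>
    __device__ aligned_vector<bool, vec_size> load_vector(
        const bool* base_ptr, uint32_t offset) {
      auto* byte_ptr = reinterpret_cast<const uint8_t*>(base_ptr);
      auto vec = load_vector<vec_size>(byte_ptr, offset);
      aligned_vector<bool, vec_size> ret;
      #pragma unroll
      for (int i = 0; i < vec_size; ++i) {
        ret.val[i] = bool(vec.val[i]);
      }
      return ret;
    }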
@@ -16,6 +16,7 @@ template <>
 struct LoadImpl<bool> {
   C10_HOST_DEVICE static bool apply(const void* src) {
     static_assert(sizeof(bool) == sizeof(char), "");
+    // NOTE: [Loading boolean values]
     // Protect against invalid boolean values by loading as a byte
     // first, then converting to bool (see gh-54789).
     return *reinterpret_cast<const unsigned char*>(src);
@@ -29,4 +30,9 @@ C10_HOST_DEVICE T load(const void* src) {
   return c10::detail::LoadImpl<T>::apply(src);
 }
 
+template <typename scalar_t>
+C10_HOST_DEVICE scalar_t load(const scalar_t* src) {
+  return c10::detail::LoadImpl<scalar_t>::apply(src);
+}
+
 } // namespace c10
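For context, a kernel that previously read boolean inputs by direct dereference would now go through `c10::load`. A hypothetical elementwise kernel, for illustration only (the kernel name is made up, and the header path `c10/util/Load.h` is an assumption):

    #include <c10/util/Load.h>

    // Reading through c10::load keeps a bool whose underlying byte is,
    // say, 0x02 well-defined: it is loaded as a byte and converted to true.
    template <typename scalar_t>
    __global__ void logical_not_kernel(const scalar_t* in, bool* out, int64_t n) {
      const int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
      if (i < n) {
        // c10::load(in + i) instead of in[i]: for scalar_t == bool this
        // round-trips through unsigned char (see the first hunk above).
        out[i] = !c10::load(in + i);
      }
    }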