fix torch.prod vectorized path for bool (#128009)

Fix https://github.com/pytorch/pytorch/issues/127866.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128009
Approved by: https://github.com/jgong5, https://github.com/albanD
haozhe.zhu
2024-08-27 09:29:50 +08:00
committed by PyTorch MergeBot
parent 89929d9abc
commit 2ba60a1618
4 changed files with 41 additions and 1 deletion
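For context (not part of the commit): a minimal repro sketch of the reported behavior, based on https://github.com/pytorch/pytorch/issues/127866 and the regression test added below. On affected builds, a bool tensor long enough to take the CPU vectorized reduction path could produce the wrong product.

import torch

# All-False input long enough to hit the vectorized CPU reduction path.
# With this fix torch.prod returns False as expected; per issue 127866 the
# vectorized path could previously yield True for such inputs.
x = torch.zeros(256, dtype=torch.bool)
print(torch.prod(x, dtype=torch.bool).item())  # expected: False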


@@ -314,6 +314,17 @@ inline Vectorized<uint8_t> flip(const Vectorized<uint8_t> & v) {
  return flip8(v);
}
inline Vectorized<bool> operator&&(
    const Vectorized<bool>& self,
    const Vectorized<bool>& other) {
  const __m256i* self_ = reinterpret_cast<const __m256i*>(self.as_bytes());
  const __m256i* other_ = reinterpret_cast<const __m256i*>(other.as_bytes());
  __m256i out = _mm256_and_si256(*self_, *other_);
  Vectorized<bool> ret;
  // We do not have a constructor that takes __m256i, so we need to memcpy
  std::memcpy(ret, &out, ret.size() * sizeof(bool));
  return ret;
}
#endif // defined(CPU_CAPABILITY_AVX2)
}} // namespace at::vec::CPU_CAPABILITY


@@ -274,6 +274,18 @@ inline Vectorized<uint8_t> flip(const Vectorized<uint8_t> & v) {
  return flip8(v);
}
inline Vectorized<bool> operator&&(
    const Vectorized<bool>& self,
    const Vectorized<bool>& other) {
  const __m512i* self_ = reinterpret_cast<const __m512i*>(self.as_bytes());
  const __m512i* other_ = reinterpret_cast<const __m512i*>(other.as_bytes());
  __m512i out = _mm512_and_si512(*self_, *other_);
  Vectorized<bool> ret;
  // We do not have a constructor that takes __m512i, so we need to memcpy
  std::memcpy(ret, &out, ret.size() * sizeof(bool));
  return ret;
}
#endif // defined(CPU_CAPABILITY_AVX512)
}}}
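An aside on the intrinsic choice (not part of the diff): each Vectorized<bool> lane is a single byte holding 0 or 1, so a lane-wise logical AND reduces to a byte-wise bitwise AND, which is what _mm256_and_si256 and _mm512_and_si512 compute over the packed register. A small NumPy sketch of that equivalence, assuming the usual 0/1 byte representation of bool:

import numpy as np

# Two vectors of 0/1 bytes, standing in for the bool lanes of an AVX register.
a = np.random.randint(0, 2, size=32, dtype=np.uint8)
b = np.random.randint(0, 2, size=32, dtype=np.uint8)

# For values restricted to 0 and 1, bitwise AND and logical AND agree lane by lane.
assert np.array_equal(a & b, np.logical_and(a, b).astype(np.uint8))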


@@ -947,6 +947,17 @@ inline Vectorized<T> fmsub(const Vectorized<T>& a, const Vectorized<T>& b, const
  return a * b - c;
}
template <typename T>
Vectorized<T> inline operator&&(
    const Vectorized<T>& a,
    const Vectorized<T>& b) {
  // Generic element-wise fallback used when no vectorized specialization exists.
  Vectorized<T> ret;
  for (int i = 0; i != Vectorized<T>::size(); i++) {
    ret[i] = a[i] && b[i];
  }
  return ret;
}
template <int64_t scale = 1, typename T = void>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<T>>
inline gather(T const* base_addr, const Vectorized<int_same_size_t<T>>& vindex) {


@@ -1498,7 +1498,13 @@ class TestReductions(TestCase):
        self.assertEqual(res1, res2.to(dtype=dtype))

    def test_prod_bool(self, device):
        vals = [[True, True], [True, False], [False, False], []]
        vals = [
            [True, True],
            [True, False],
            [False, False],
            [],
            [False] * 256,  # https://github.com/pytorch/pytorch/issues/127866
        ]
        for val in vals:
            result = torch.prod(torch.tensor(val, device=device), dtype=torch.bool).item()
            expect = np.prod(np.array(val), dtype=bool)
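The hunk ends before the assertion, so as a standalone sketch (illustrative, not the test body itself) the property being exercised is that torch.prod on bool matches NumPy for every case above, including the empty input, whose product is the reduction identity True:

import numpy as np
import torch

for val in [[True, True], [True, False], [False, False], [], [False] * 256]:
    result = torch.prod(torch.tensor(val, dtype=torch.bool), dtype=torch.bool).item()
    expect = bool(np.prod(np.array(val), dtype=bool))
    assert result == expect  # empty input: product is the identity, True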