mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
fix torch.prod vectorized path for bool (#128009)
Fix https://github.com/pytorch/pytorch/issues/127866. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128009 Approved by: https://github.com/jgong5, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
89929d9abc
commit
2ba60a1618
@ -314,6 +314,17 @@ inline Vectorized<uint8_t> flip(const Vectorized<uint8_t> & v) {
|
||||
return flip8(v);
|
||||
}
|
||||
|
||||
// Element-wise logical AND for boolean vectors (AVX2 path).
// Assumes bool lanes use the canonical 0/1 byte representation, so a
// bitwise AND of the raw 256-bit lanes is equivalent to a logical AND
// per element — TODO confirm against Vectorized<bool> storage contract.
inline Vectorized<bool> operator&&(
    const Vectorized<bool>& self,
    const Vectorized<bool>& other) {
  // Reinterpret both operands as raw 256-bit integer lanes.
  const __m256i* lhs_bits = reinterpret_cast<const __m256i*>(self.as_bytes());
  const __m256i* rhs_bits = reinterpret_cast<const __m256i*>(other.as_bytes());
  const __m256i anded = _mm256_and_si256(*lhs_bits, *rhs_bits);
  // Vectorized<bool> has no constructor taking __m256i, so copy the raw
  // result bytes into a default-constructed vector instead.
  Vectorized<bool> result;
  std::memcpy(result, &anded, result.size() * sizeof(bool));
  return result;
}
|
||||
|
||||
#endif // (defined(CPU_CAPABILITY_AVX2)
|
||||
|
||||
}} // namespace at::vec::CPU_CAPABILITY
|
||||
|
@ -274,6 +274,18 @@ inline Vectorized<uint8_t> flip(const Vectorized<uint8_t> & v) {
|
||||
return flip8(v);
|
||||
}
|
||||
|
||||
// Element-wise logical AND for boolean vectors (AVX512 path).
// Assumes bool lanes use the canonical 0/1 byte representation, so a
// bitwise AND of the raw 512-bit lanes is equivalent to a logical AND
// per element — TODO confirm against Vectorized<bool> storage contract.
inline Vectorized<bool> operator&&(
    const Vectorized<bool>& self,
    const Vectorized<bool>& other) {
  // Reinterpret both operands as raw 512-bit integer lanes.
  const __m512i* lhs_bits = reinterpret_cast<const __m512i*>(self.as_bytes());
  const __m512i* rhs_bits = reinterpret_cast<const __m512i*>(other.as_bytes());
  const __m512i anded = _mm512_and_si512(*lhs_bits, *rhs_bits);
  // Vectorized<bool> has no constructor taking __m512i, so copy the raw
  // result bytes into a default-constructed vector instead.
  Vectorized<bool> result;
  std::memcpy(result, &anded, result.size() * sizeof(bool));
  return result;
}
|
||||
|
||||
#endif // defined(CPU_CAPABILITY_AVX512)
|
||||
|
||||
}}}
|
||||
|
@ -947,6 +947,17 @@ inline Vectorized<T> fmsub(const Vectorized<T>& a, const Vectorized<T>& b, const
|
||||
return a * b - c;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Vectorized<T> inline operator&&(
|
||||
const Vectorized<T>& a,
|
||||
const Vectorized<T>& b) {
|
||||
Vectorized<T> ret;
|
||||
for (int i = 0; i != Vectorized<T>::size(); i++) {
|
||||
ret[i] = a[i] && b[i];
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <int64_t scale = 1, typename T = void>
|
||||
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<T>>
|
||||
inline gather(T const* base_addr, const Vectorized<int_same_size_t<T>>& vindex) {
|
||||
|
@ -1498,7 +1498,13 @@ class TestReductions(TestCase):
|
||||
self.assertEqual(res1, res2.to(dtype=dtype))
|
||||
|
||||
def test_prod_bool(self, device):
|
||||
vals = [[True, True], [True, False], [False, False], []]
|
||||
vals = [
|
||||
[True, True],
|
||||
[True, False],
|
||||
[False, False],
|
||||
[],
|
||||
[False] * 256, # https://github.com/pytorch/pytorch/issues/127866
|
||||
]
|
||||
for val in vals:
|
||||
result = torch.prod(torch.tensor(val, device=device), dtype=torch.bool).item()
|
||||
expect = np.prod(np.array(val), dtype=bool)
|
||||
|
Reference in New Issue
Block a user