Revert "Optimize scatter_add/scatter_reduce in BFloat16/Half data type in CPU backend (#103427)"

This reverts commit da7675621efce341c80187e404ac62cb6c22bbf8.

Reverted https://github.com/pytorch/pytorch/pull/103427 on behalf of https://github.com/clee2000 due to sorry but it looks like this pr broke test_scatter_gather_ops.py::TestScatterGatherCPU::test_scatter_expanded_index_cpu_bfloat16 on periodic parallelnative testing da7675621e https://github.com/pytorch/pytorch/actions/runs/5477783108/jobs/9977608393 ([comment](https://github.com/pytorch/pytorch/pull/103427#issuecomment-1624008753))
This commit is contained in:
PyTorch MergeBot
2023-07-06 17:01:59 +00:00
parent c4cf90aad1
commit f8aedf1efe
4 changed files with 13 additions and 122 deletions

View File

@ -6,7 +6,6 @@
#include <ATen/cpu/vec/functional.h>
#include <ATen/native/ReductionType.h>
#include <c10/util/irange.h>
#include <ATen/OpMathType.h>
namespace at::native {
inline namespace CPU_CAPABILITY {
@ -94,15 +93,6 @@ inline void init(scalar_t* out, int64_t size, bool include_self = false) {
}
}
template <typename scalar_t, ReductionType reduce>
inline void _init(scalar_t* self_ptr, at::opmath_type<scalar_t>* buffer_ptr, int64_t size, bool include_self) {
if (!include_self) {
init<at::opmath_type<scalar_t>, reduce>(buffer_ptr, size, include_self);
} else {
vec::convert(self_ptr, buffer_ptr, size);
}
}
template <typename scalar_t>
inline scalar_t _max(const scalar_t& x, const scalar_t& y) {
return at::_isnan(y) ? y : std::max(x, y);
@ -125,45 +115,6 @@ inline Vectorized<scalar_t> _min(const Vectorized<scalar_t>& x, const Vectorized
return vec::minimum(x, y);
}
template <typename scalar_t, typename accumut, typename Op,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline void map_acc(
const Op& vec_fun,
accumut* output_data,
const accumut* input_data,
const scalar_t* input_data2,
int64_t size) {
using Vec = vec::Vectorized<scalar_t>;
using aVec = vec::Vectorized<accumut>;
int64_t d = 0;
constexpr int64_t kVecSize = Vec::size();
constexpr int64_t kaVecSize = aVec::size();
for (d = 0; d < size - (size % kVecSize); d += kVecSize) {
Vec data2_vec = Vec::loadu(input_data2 + d);
aVec data2_avec0, data2_avec1;
std::tie(data2_avec0, data2_avec1) = convert_to_float<scalar_t>(data2_vec);
aVec input_vec0 = aVec::loadu(input_data + d);
aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize);
vec_fun(input_vec0, data2_avec0).store(output_data + d);
vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize);
}
if (size - d > 0) {
int64_t tail_size = size - d;
Vec data2_vec = Vec::loadu(input_data2 + d, tail_size);
aVec data2_avec0, data2_avec1;
std::tie(data2_avec0, data2_avec1) = convert_to_float<scalar_t>(data2_vec);
if (tail_size > kaVecSize) {
aVec input_vec0 = aVec::loadu(input_data + d);
aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize, tail_size - kaVecSize);
vec_fun(input_vec0, data2_avec0).store(output_data + d);
vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize, tail_size - kaVecSize);
} else {
aVec input_vec0 = aVec::loadu(input_data + d, tail_size);
vec_fun(input_vec0, data2_avec0).store(output_data + d, tail_size);
}
}
}
// for Max and Min, propagate NaN:
template <typename T, ReductionType reduce>
inline T update(const T& x, const T& y) {
@ -191,19 +142,6 @@ inline void update(scalar_t* out, scalar_t* data, int64_t K) {
K);
}
template <typename scalar_t, ReductionType reduce,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline void update(at::opmath_type<scalar_t>* out, scalar_t* data, int64_t K) {
using opmath_t = at::opmath_type<scalar_t>;
using Vec = vec::Vectorized<opmath_t>;
map_acc<scalar_t, opmath_t>(
[](Vec x, Vec y) { return update<Vec, reduce>(x, y); },
out,
out,
data,
K);
}
template <typename scalar_t, ReductionType reduce>
inline void write(scalar_t* out, int64_t count, int64_t K) {
using Vec = vec::Vectorized<vec_scalar_t<scalar_t>>;

View File

@ -12,14 +12,7 @@
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/irange.h>
#include <ATen/OpMathType.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/zeros.h>
#endif
namespace at::native {
namespace {
@ -604,7 +597,6 @@ struct cpu_scatter_gather_base_kernel {
//
// step 2: spmm reduce, parallel on M and vectorize on K
//
template <typename scalar_t, ReductionType reduce>
void cpu_scatter_reduce_expanded_index(const Tensor& self, const Tensor& index, const Tensor& src, bool include_self) {
int64_t* index_data = index.data_ptr<int64_t>();
@ -682,44 +674,21 @@ void cpu_scatter_reduce_expanded_index(const Tensor& self, const Tensor& index,
}
});
using opmath_t = at::opmath_type<scalar_t>;
Tensor buffer;
opmath_t* buffer_data = nullptr;
static constexpr bool need_acc = is_reduced_floating_point_v<scalar_t>;
if constexpr (need_acc) {
auto acc_type = at::toAccumulateType(self.scalar_type(), /*is_cuda=*/true);
buffer = at::zeros({num_threads, K}, self.options().dtype(acc_type));
buffer_data = buffer.data_ptr<opmath_t>();
}
// TODO: do blocking on col dimension to reduce WR bandwidth
at::parallel_for(0, num_nonzero_rows, 1, [&](int64_t begin, int64_t end) {
int tid = at::get_thread_num();
TORCH_CHECK(tid < num_threads,
"expect thread id smaller than ", num_threads, ", got thread id ", tid);
opmath_t* buffer_ptr = nullptr;
for (const auto m : c10::irange(begin, end)) {
int64_t row = row_index[m];
int64_t off_start = row_index_offset[m];
int64_t off_end = row_index_offset[m + 1];
scalar_t* self_ptr = self_data + row * K;
if constexpr (need_acc) {
buffer_ptr = buffer_data + tid * K;
} else {
buffer_ptr = reinterpret_cast<opmath_t*>(self_ptr);
}
// step 1: reinit rows in `self` if needed
_init<scalar_t, reduce>(self_ptr, buffer_ptr, K, include_self);
init<scalar_t, reduce>(self_ptr, K, include_self);
// step 2: reduce
for (const auto n : c10::irange(off_start, off_end)) {
int64_t col = sorted_col_index_values[n];
update<scalar_t, reduce>(buffer_ptr, src_data + col * K, K);
}
if constexpr (need_acc) {
vec::convert(buffer_ptr, self_ptr, K);
update<scalar_t, reduce>(self_ptr, src_data + col * K, K);
}
// step 3: finalize
@ -769,8 +738,8 @@ void cpu_gather_expanded_index_kernel(const Tensor& result, const Tensor& index,
}
void scatter_add_expanded_index_kernel(const Tensor& self, const Tensor& index, const Tensor& src) {
AT_DISPATCH_FLOATING_TYPES_AND2(
ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "scatter_add_expanded_index", [&] {
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, self.scalar_type(), "scatter_add_expanded_index", [&] {
cpu_scatter_reduce_expanded_index<scalar_t, ReductionType::SUM>(self, index, src, /*include_self*/true);
});
}
@ -778,8 +747,8 @@ void scatter_add_expanded_index_kernel(const Tensor& self, const Tensor& index,
void scatter_reduce_expanded_index_kernel(
const Tensor& self, const Tensor& index, const Tensor& src,
const ReductionType& reduction, bool include_self) {
AT_DISPATCH_FLOATING_TYPES_AND2(
ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "scatter_reduce_expanded_index", [&] {
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, self.scalar_type(), "scatter_reduce_expanded_index", [&] {
AT_DISPATCH_REDUCTION_TYPES(reduction, [&]() {
cpu_scatter_reduce_expanded_index<scalar_t, reduce>(self, index, src, include_self);
});

View File

@ -277,16 +277,11 @@ class TestScatterGather(TestCase):
self.assertEqual(input, expected_result)
@onlyCPU
@dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
@dtypes(torch.float32, torch.float64, torch.bfloat16)
def test_scatter_expanded_index(self, device, dtype):
def helper(input_size, idx_size, atol=1e-5, rtol=0.016):
is_reduced_type = dtype in [torch.bfloat16, torch.float16]
if is_reduced_type:
atol = 1e-2
rtol = 1e-2
def helper(input_size, idx_size):
input = torch.randn(input_size, device=device).to(dtype=dtype)
input2 = input.clone().to(torch.float32) if is_reduced_type else input.clone()
input3 = input.clone()
input2 = input.clone()
shape = [1] * len(input_size)
shape[0] = idx_size
@ -305,27 +300,17 @@ class TestScatterGather(TestCase):
idx = idx.expand(expanded_shape)
idx2 = idx.contiguous()
src = torch.randn(expanded_shape, device=device).to(dtype=dtype)
src2 = src.clone().to(torch.float32) if is_reduced_type else src.clone()
out = input.scatter_add(0, idx, src)
out2 = input2.scatter_add(0, idx2, src2)
out2 = input2.scatter_add(0, idx2, src)
if torch.has_openmp:
self.assertEqual(out, out2.to(dtype) if is_reduced_type else out2, atol=atol, rtol=rtol)
else:
out3 = input3.scatter_add(0, idx2, src)
self.assertEqual(out, out3)
self.assertEqual(out, out2)
for reduce in ["sum", "prod", "mean", "amax", "amin"]:
for include_self in [True, False]:
out = input.scatter_reduce(0, idx, src, reduce=reduce, include_self=include_self)
out2 = input2.scatter_reduce(0, idx2, src2, reduce=reduce, include_self=include_self)
if torch.has_openmp:
self.assertEqual(out, out2.to(dtype) if is_reduced_type else out2,
atol=atol, rtol=rtol)
else:
out3 = input3.scatter_reduce(0, idx2, src, reduce=reduce, include_self=include_self)
self.assertEqual(out, out3)
out2 = input2.scatter_reduce(0, idx2, src, reduce=reduce, include_self=include_self)
self.assertEqual(out, out2)
helper([50, 17], 100)
helper([50, 1], 100)

View File

@ -2517,7 +2517,6 @@ class TestSparseCSR(TestCase):
@onlyCPU
@dtypes(torch.float32, torch.float64, torch.bfloat16)
@precisionOverride({torch.bfloat16: 0.01})
def test_sparse_mm_reduce(self, device, dtype):
def run_test(m, n, k, nnz, reduce_type, index_dtype, train):
csr = self.genSparseCSRTensor((m, n), nnz, dtype=dtype, device=device, index_dtype=index_dtype)