Optimize scatter_add/scatter_reduce in BFloat16/Half data type in CPU backend (#103427)
### Description

This PR optimizes scatter_add/scatter_reduce for the BFloat16/Half data types in the CPU backend, one of the tasks tracked in https://github.com/pyg-team/pytorch_geometric/issues/7057. The main idea is to create a per-thread buffer in which intermediate results are accumulated as fp32, converting back to the reduced type only once per output row.

Next steps:
- [x] Add benchmarks
- [x] Extend to Half
- [x] Simplify code

### Performance test (Updated)

Tested BFloat16 on Intel(R) Xeon(R) Platinum 8380 CPU @ 2.30GHz, with jemalloc and iomp.

Single socket (40C): benchmark table attached as an image in the original PR.

Single core: benchmark table attached as an image in the original PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103427
Approved by: https://github.com/mingfeima, https://github.com/albanD
Committed by: PyTorch MergeBot
Parent commit: bf127d236a
This commit: da7675621e
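Before the diff, here is a minimal, hypothetical Python snippet (not part of the PR; the sizes and tolerances are borrowed from the test changes further below) that exercises the optimized scatter_add path and checks the reduced-precision result against an fp32 reference:

```python
import torch

# Hypothetical repro: scatter_add with an expanded index on a BFloat16 tensor,
# checked against an fp32 reference with the relaxed tolerances used by the
# updated test. With fp32 accumulation in the CPU kernel, the bf16 result
# should track the reference up to bf16 rounding.
input_size, idx_size = [50, 17], 100

idx = torch.randint(0, input_size[0], (idx_size, 1)).expand(idx_size, input_size[1])
src = torch.randn(idx_size, input_size[1])

dst_bf16 = torch.randn(input_size).to(torch.bfloat16)
dst_fp32 = dst_bf16.clone().to(torch.float32)

out_bf16 = dst_bf16.scatter_add(0, idx, src.to(torch.bfloat16))
out_fp32 = dst_fp32.scatter_add(0, idx, src)

torch.testing.assert_close(out_bf16.float(), out_fp32, atol=1e-2, rtol=1e-2)
```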
@@ -6,6 +6,7 @@
 #include <ATen/cpu/vec/functional.h>
 #include <ATen/native/ReductionType.h>
 #include <c10/util/irange.h>
+#include <ATen/OpMathType.h>
 
 namespace at::native {
 inline namespace CPU_CAPABILITY {
@@ -93,6 +94,15 @@ inline void init(scalar_t* out, int64_t size, bool include_self = false) {
   }
 }
 
+template <typename scalar_t, ReductionType reduce>
+inline void _init(scalar_t* self_ptr, at::opmath_type<scalar_t>* buffer_ptr, int64_t size, bool include_self) {
+  if (!include_self) {
+    init<at::opmath_type<scalar_t>, reduce>(buffer_ptr, size, include_self);
+  } else {
+    vec::convert(self_ptr, buffer_ptr, size);
+  }
+}
+
 template <typename scalar_t>
 inline scalar_t _max(const scalar_t& x, const scalar_t& y) {
   return at::_isnan(y) ? y : std::max(x, y);
@@ -115,6 +125,45 @@ inline Vectorized<scalar_t> _min(const Vectorized<scalar_t>& x, const Vectorized
   return vec::minimum(x, y);
 }
 
+template <typename scalar_t, typename accumut, typename Op,
+          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline void map_acc(
+    const Op& vec_fun,
+    accumut* output_data,
+    const accumut* input_data,
+    const scalar_t* input_data2,
+    int64_t size) {
+  using Vec = vec::Vectorized<scalar_t>;
+  using aVec = vec::Vectorized<accumut>;
+  int64_t d = 0;
+  constexpr int64_t kVecSize = Vec::size();
+  constexpr int64_t kaVecSize = aVec::size();
+  for (d = 0; d < size - (size % kVecSize); d += kVecSize) {
+    Vec data2_vec = Vec::loadu(input_data2 + d);
+    aVec data2_avec0, data2_avec1;
+    std::tie(data2_avec0, data2_avec1) = convert_to_float<scalar_t>(data2_vec);
+    aVec input_vec0 = aVec::loadu(input_data + d);
+    aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize);
+    vec_fun(input_vec0, data2_avec0).store(output_data + d);
+    vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize);
+  }
+  if (size - d > 0) {
+    int64_t tail_size = size - d;
+    Vec data2_vec = Vec::loadu(input_data2 + d, tail_size);
+    aVec data2_avec0, data2_avec1;
+    std::tie(data2_avec0, data2_avec1) = convert_to_float<scalar_t>(data2_vec);
+    if (tail_size > kaVecSize) {
+      aVec input_vec0 = aVec::loadu(input_data + d);
+      aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize, tail_size - kaVecSize);
+      vec_fun(input_vec0, data2_avec0).store(output_data + d);
+      vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize, tail_size - kaVecSize);
+    } else {
+      aVec input_vec0 = aVec::loadu(input_data + d, tail_size);
+      vec_fun(input_vec0, data2_avec0).store(output_data + d, tail_size);
+    }
+  }
+}
+
 // for Max and Min, propagate NaN:
 template <typename T, ReductionType reduce>
 inline T update(const T& x, const T& y) {
@@ -142,6 +191,19 @@ inline void update(scalar_t* out, scalar_t* data, int64_t K) {
       K);
 }
 
+template <typename scalar_t, ReductionType reduce,
+          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline void update(at::opmath_type<scalar_t>* out, scalar_t* data, int64_t K) {
+  using opmath_t = at::opmath_type<scalar_t>;
+  using Vec = vec::Vectorized<opmath_t>;
+  map_acc<scalar_t, opmath_t>(
+      [](Vec x, Vec y) { return update<Vec, reduce>(x, y); },
+      out,
+      out,
+      data,
+      K);
+}
+
 template <typename scalar_t, ReductionType reduce>
 inline void write(scalar_t* out, int64_t count, int64_t K) {
   using Vec = vec::Vectorized<vec_scalar_t<scalar_t>>;
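The helpers above follow a simple division of labor: `_init` either fills the fp32 buffer with the reduction identity (`include_self == false`) or up-converts the current `self` row into it, and the new `update` overload uses `map_acc` to combine an fp32 accumulator row with a bf16/fp16 source row, up-converting the source on the fly and storing back in fp32. A rough Python sketch of the row-level semantics (illustrative only; the helper names mirror the C++ ones, and using torch ops in place of the ATen vectorization API is an assumption made for readability):

```python
import torch

def _init(self_row, buffer_row, include_self, identity):
    # buffer_row is fp32; self_row is bf16/fp16.
    if not include_self:
        buffer_row.fill_(identity)       # start from the reduction identity
    else:
        buffer_row.copy_(self_row)       # up-convert the existing self values to fp32

def update(buffer_row, data_row, op):
    # map_acc semantics: op(fp32 accumulator, up-converted bf16/fp16 input) -> fp32
    buffer_row.copy_(op(buffer_row, data_row.to(buffer_row.dtype)))

# Accumulate two bf16 rows into an fp32 buffer with a sum reduction,
# then write back to the reduced type once at the end.
buf = torch.empty(8, dtype=torch.float32)
row = torch.randn(8, dtype=torch.bfloat16)
_init(row, buf, include_self=True, identity=0.0)
update(buf, torch.randn(8, dtype=torch.bfloat16), torch.add)
out = buf.to(torch.bfloat16)
```

The real code does the same with `vec::Vectorized` loads/stores and `convert_to_float`, processing two fp32 vectors for every bf16/fp16 vector.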
@@ -12,7 +12,14 @@
 #include <ATen/cpu/vec/functional.h>
 #include <ATen/cpu/vec/vec.h>
 #include <c10/util/irange.h>
+#include <ATen/OpMathType.h>
 
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#include <ATen/NativeFunctions.h>
+#else
+#include <ATen/ops/zeros.h>
+#endif
 namespace at::native {
 
 namespace {
@@ -597,6 +604,7 @@ struct cpu_scatter_gather_base_kernel {
 //
 // step 2: spmm reduce, parallel on M and vectorize on K
 //
+
 template <typename scalar_t, ReductionType reduce>
 void cpu_scatter_reduce_expanded_index(const Tensor& self, const Tensor& index, const Tensor& src, bool include_self) {
   int64_t* index_data = index.data_ptr<int64_t>();
@@ -674,21 +682,44 @@ void cpu_scatter_reduce_expanded_index(const Tensor& self, const Tensor& index,
     }
   });
 
+  using opmath_t = at::opmath_type<scalar_t>;
+  Tensor buffer;
+  opmath_t* buffer_data = nullptr;
+  static constexpr bool need_acc = is_reduced_floating_point_v<scalar_t>;
+  if constexpr (need_acc) {
+    auto acc_type = at::toAccumulateType(self.scalar_type(), /*is_cuda=*/true);
+    buffer = at::zeros({num_threads, K}, self.options().dtype(acc_type));
+    buffer_data = buffer.data_ptr<opmath_t>();
+  }
+
   // TODO: do blocking on col dimension to reduce WR bandwidth
   at::parallel_for(0, num_nonzero_rows, 1, [&](int64_t begin, int64_t end) {
+    int tid = at::get_thread_num();
+    TORCH_CHECK(tid < num_threads,
+        "expect thread id smaller than ", num_threads, ", got thread id ", tid);
+    opmath_t* buffer_ptr = nullptr;
+
     for (const auto m : c10::irange(begin, end)) {
       int64_t row = row_index[m];
       int64_t off_start = row_index_offset[m];
       int64_t off_end = row_index_offset[m + 1];
       scalar_t* self_ptr = self_data + row * K;
+      if constexpr (need_acc) {
+        buffer_ptr = buffer_data + tid * K;
+      } else {
+        buffer_ptr = reinterpret_cast<opmath_t*>(self_ptr);
+      }
 
       // step 1: reinit rows in `self` if needed
-      init<scalar_t, reduce>(self_ptr, K, include_self);
+      _init<scalar_t, reduce>(self_ptr, buffer_ptr, K, include_self);
 
       // step 2: reduce
       for (const auto n : c10::irange(off_start, off_end)) {
         int64_t col = sorted_col_index_values[n];
-        update<scalar_t, reduce>(self_ptr, src_data + col * K, K);
+        update<scalar_t, reduce>(buffer_ptr, src_data + col * K, K);
       }
+      if constexpr (need_acc) {
+        vec::convert(buffer_ptr, self_ptr, K);
+      }
 
       // step 3: finalize
@@ -738,8 +769,8 @@ void cpu_gather_expanded_index_kernel(const Tensor& result, const Tensor& index,
 }
 
 void scatter_add_expanded_index_kernel(const Tensor& self, const Tensor& index, const Tensor& src) {
-  AT_DISPATCH_FLOATING_TYPES_AND(
-    ScalarType::BFloat16, self.scalar_type(), "scatter_add_expanded_index", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "scatter_add_expanded_index", [&] {
     cpu_scatter_reduce_expanded_index<scalar_t, ReductionType::SUM>(self, index, src, /*include_self*/true);
   });
 }
@@ -747,8 +778,8 @@ void scatter_add_expanded_index_kernel(const Tensor& self, const Tensor& index,
 void scatter_reduce_expanded_index_kernel(
     const Tensor& self, const Tensor& index, const Tensor& src,
     const ReductionType& reduction, bool include_self) {
-  AT_DISPATCH_FLOATING_TYPES_AND(
-    ScalarType::BFloat16, self.scalar_type(), "scatter_reduce_expanded_index", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "scatter_reduce_expanded_index", [&] {
     AT_DISPATCH_REDUCTION_TYPES(reduction, [&]() {
       cpu_scatter_reduce_expanded_index<scalar_t, reduce>(self, index, src, include_self);
     });
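For reduced floating-point types the kernel above allocates a `{num_threads, K}` fp32 buffer, lets each thread reduce all contributions to a destination row into its own buffer slice, and converts back to bf16/fp16 once per row, instead of doing a read-modify-write on the reduced-precision row for every source element. A simplified, single-threaded Python sketch of the per-row flow (hypothetical function and argument names, sum reduction only):

```python
import torch

def scatter_add_rowwise(self_t, sorted_rows, row_offsets, col_index, src, include_self=True):
    # self_t: (M, K) bf16/fp16 destination; src: (N, K) bf16/fp16 source.
    # sorted_rows[m] is a destination row, and col_index[row_offsets[m]:row_offsets[m + 1]]
    # lists the src rows that scatter into it (mirroring the kernel's sorted, CSR-like layout).
    K = self_t.size(1)
    buffer = torch.zeros(K, dtype=torch.float32)  # the per-thread fp32 accumulator row
    for m in range(len(sorted_rows)):
        row = sorted_rows[m]
        # step 1: (re)initialize the accumulator
        if include_self:
            buffer.copy_(self_t[row])             # up-convert existing values to fp32
        else:
            buffer.zero_()                        # start from the sum identity
        # step 2: reduce every contribution in fp32
        for n in range(row_offsets[m], row_offsets[m + 1]):
            buffer += src[col_index[n]].to(torch.float32)
        # step 3: convert back to the destination dtype once per row
        self_t[row] = buffer.to(self_t.dtype)
    return self_t
```

Keeping the running sum in fp32 avoids rounding to the reduced type after every update and keeps the inner loop vectorizable; for fp32/fp64 inputs `buffer_ptr` simply aliases the row of `self`, so the non-reduced path is unchanged.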
@@ -277,11 +277,16 @@ class TestScatterGather(TestCase):
         self.assertEqual(input, expected_result)
 
     @onlyCPU
-    @dtypes(torch.float32, torch.float64, torch.bfloat16)
+    @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
     def test_scatter_expanded_index(self, device, dtype):
-        def helper(input_size, idx_size):
+        def helper(input_size, idx_size, atol=1e-5, rtol=0.016):
+            is_reduced_type = dtype in [torch.bfloat16, torch.float16]
+            if is_reduced_type:
+                atol = 1e-2
+                rtol = 1e-2
             input = torch.randn(input_size, device=device).to(dtype=dtype)
-            input2 = input.clone()
+            input2 = input.clone().to(torch.float32) if is_reduced_type else input.clone()
+            input3 = input.clone()
 
             shape = [1] * len(input_size)
             shape[0] = idx_size
@@ -300,17 +305,27 @@ class TestScatterGather(TestCase):
             idx = idx.expand(expanded_shape)
             idx2 = idx.contiguous()
             src = torch.randn(expanded_shape, device=device).to(dtype=dtype)
+            src2 = src.clone().to(torch.float32) if is_reduced_type else src.clone()
 
             out = input.scatter_add(0, idx, src)
-            out2 = input2.scatter_add(0, idx2, src)
+            out2 = input2.scatter_add(0, idx2, src2)
 
-            self.assertEqual(out, out2)
+            if torch.has_openmp:
+                self.assertEqual(out, out2.to(dtype) if is_reduced_type else out2, atol=atol, rtol=rtol)
+            else:
+                out3 = input3.scatter_add(0, idx2, src)
+                self.assertEqual(out, out3)
 
             for reduce in ["sum", "prod", "mean", "amax", "amin"]:
                 for include_self in [True, False]:
                     out = input.scatter_reduce(0, idx, src, reduce=reduce, include_self=include_self)
-                    out2 = input2.scatter_reduce(0, idx2, src, reduce=reduce, include_self=include_self)
-                    self.assertEqual(out, out2)
+                    out2 = input2.scatter_reduce(0, idx2, src2, reduce=reduce, include_self=include_self)
+                    if torch.has_openmp:
+                        self.assertEqual(out, out2.to(dtype) if is_reduced_type else out2,
+                                         atol=atol, rtol=rtol)
+                    else:
+                        out3 = input3.scatter_reduce(0, idx2, src, reduce=reduce, include_self=include_self)
+                        self.assertEqual(out, out3)
 
         helper([50, 17], 100)
         helper([50, 1], 100)
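When OpenMP is available, the updated test compares the reduced-precision result against an fp32 reference with relaxed tolerances (atol = rtol = 1e-2); otherwise it falls back to an exact same-dtype comparison. The following standalone snippet (illustrative only, not part of the test suite) reports how far the bf16 path drifts from an fp32 reference for each reduction mode:

```python
import torch

# Illustrative only: compare bf16 scatter_reduce against an fp32 reference for
# each reduction mode and report the worst absolute difference, mirroring the
# structure of the test above.
idx = torch.randint(0, 50, (100, 1)).expand(100, 17)
src = torch.randn(100, 17)
base = torch.randn(50, 17)

for reduce in ["sum", "prod", "mean", "amax", "amin"]:
    for include_self in [True, False]:
        out = base.to(torch.bfloat16).scatter_reduce(
            0, idx, src.to(torch.bfloat16), reduce=reduce, include_self=include_self)
        ref = base.scatter_reduce(0, idx, src, reduce=reduce, include_self=include_self)
        max_diff = (out.float() - ref).abs().max().item()
        print(f"{reduce:5s} include_self={include_self}: max |bf16 - fp32| = {max_diff:.4f}")
```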
@@ -2517,6 +2517,7 @@ class TestSparseCSR(TestCase):
 
     @onlyCPU
     @dtypes(torch.float32, torch.float64, torch.bfloat16)
+    @precisionOverride({torch.bfloat16: 0.01})
     def test_sparse_mm_reduce(self, device, dtype):
         def run_test(m, n, k, nnz, reduce_type, index_dtype, train):
             csr = self.genSparseCSRTensor((m, n), nnz, dtype=dtype, device=device, index_dtype=index_dtype)