Optimize scatter_add/scatter_reduce in BFloat16/Half data type in CPU backend (#103427)

### Description

This PR optimizes scatter_add/scatter_reduce for the BFloat16/Half data types in the CPU backend, which is one task in https://github.com/pyg-team/pytorch_geometric/issues/7057. The main idea is to create a per-thread buffer that accumulates intermediate data in fp32.

Next step:

 - [x] Add benchmarks
 - [x] Extend to Half
 - [x] Simplify code

### Performance test (Updated)

Test BFloat16 in Intel(R) Xeon(R) Platinum 8380 CPU @ 2.30GHz
With jemalloc and iomp

Single socket (40C)
![image](https://github.com/pytorch/pytorch/assets/61222868/4b4342f1-8cc3-46f7-81f5-651becd9b1e3)

Single core
![image](https://github.com/pytorch/pytorch/assets/61222868/09e5f700-2c2e-4208-979e-74b85474dea6)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103427
Approved by: https://github.com/mingfeima, https://github.com/albanD
This commit is contained in:
yanbing-j
2023-07-06 01:23:56 +00:00
committed by PyTorch MergeBot
parent bf127d236a
commit da7675621e
4 changed files with 122 additions and 13 deletions

View File

@ -277,11 +277,16 @@ class TestScatterGather(TestCase):
self.assertEqual(input, expected_result)
@onlyCPU
@dtypes(torch.float32, torch.float64, torch.bfloat16)
@dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
def test_scatter_expanded_index(self, device, dtype):
def helper(input_size, idx_size):
def helper(input_size, idx_size, atol=1e-5, rtol=0.016):
is_reduced_type = dtype in [torch.bfloat16, torch.float16]
if is_reduced_type:
atol = 1e-2
rtol = 1e-2
input = torch.randn(input_size, device=device).to(dtype=dtype)
input2 = input.clone()
input2 = input.clone().to(torch.float32) if is_reduced_type else input.clone()
input3 = input.clone()
shape = [1] * len(input_size)
shape[0] = idx_size
@ -300,17 +305,27 @@ class TestScatterGather(TestCase):
idx = idx.expand(expanded_shape)
idx2 = idx.contiguous()
src = torch.randn(expanded_shape, device=device).to(dtype=dtype)
src2 = src.clone().to(torch.float32) if is_reduced_type else src.clone()
out = input.scatter_add(0, idx, src)
out2 = input2.scatter_add(0, idx2, src)
out2 = input2.scatter_add(0, idx2, src2)
self.assertEqual(out, out2)
if torch.has_openmp:
self.assertEqual(out, out2.to(dtype) if is_reduced_type else out2, atol=atol, rtol=rtol)
else:
out3 = input3.scatter_add(0, idx2, src)
self.assertEqual(out, out3)
for reduce in ["sum", "prod", "mean", "amax", "amin"]:
for include_self in [True, False]:
out = input.scatter_reduce(0, idx, src, reduce=reduce, include_self=include_self)
out2 = input2.scatter_reduce(0, idx2, src, reduce=reduce, include_self=include_self)
self.assertEqual(out, out2)
out2 = input2.scatter_reduce(0, idx2, src2, reduce=reduce, include_self=include_self)
if torch.has_openmp:
self.assertEqual(out, out2.to(dtype) if is_reduced_type else out2,
atol=atol, rtol=rtol)
else:
out3 = input3.scatter_reduce(0, idx2, src, reduce=reduce, include_self=include_self)
self.assertEqual(out, out3)
helper([50, 17], 100)
helper([50, 1], 100)