Optimize scatter_add/scatter_reduce for the BFloat16/Half data types in the CPU backend (#103427)
### Description

This PR optimizes scatter_add/scatter_reduce for the BFloat16/Half data types in the CPU backend, one of the tasks tracked in https://github.com/pyg-team/pytorch_geometric/issues/7057. The main idea is to give each thread a buffer in which intermediate results are accumulated as fp32, so the reduced-precision output is rounded only once at the end (a Python sketch of the idea appears before the diff below).

Next steps:

- [x] Add benchmarks
- [x] Extend to Half
- [x] Simplify code

### Performance test (Updated)

BFloat16 was tested on an Intel(R) Xeon(R) Platinum 8380 CPU @ 2.30GHz, with jemalloc and iomp.

Single socket (40C) and single-core benchmark tables are attached to the PR (not reproduced here).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103427
Approved by: https://github.com/mingfeima, https://github.com/albanD
committed by PyTorch MergeBot
parent bf127d236a
commit da7675621e
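To make the buffering idea concrete before the diff, here is a minimal Python sketch. It is not the PR's actual implementation (that is a C++ CPU kernel); `scatter_add_fp32_buffer` is a hypothetical name, and the sketch assumes only what the description above states: accumulate in fp32, convert back once.

```python
import torch

def scatter_add_fp32_buffer(dest, dim, index, src):
    # Hypothetical helper mimicking the PR's strategy: accumulate
    # reduced-precision (bf16/fp16) values in an fp32 buffer, then
    # round back to the original dtype in one final step.
    buf = dest.to(torch.float32)                       # fp32 accumulation buffer
    buf.scatter_add_(dim, index, src.to(torch.float32))
    return buf.to(dest.dtype)

# Why fp32 accumulation matters: chaining bf16 additions rounds after
# every step, so small addends are eventually swallowed entirely.
naive = torch.tensor(0.0, dtype=torch.bfloat16)
for _ in range(10_000):
    naive = naive + torch.tensor(0.01, dtype=torch.bfloat16)

dest = torch.zeros(1, dtype=torch.bfloat16)
index = torch.zeros(10_000, dtype=torch.long)          # every hit lands in slot 0
src = torch.full((10_000,), 0.01, dtype=torch.bfloat16)
buffered = scatter_add_fp32_buffer(dest, 0, index, src)

print(naive.item())        # stalls far below the true sum of ~100
print(buffered[0].item())  # ~100: only one rounding, at the end
```

Converting back once bounds the rounding error to a single cast instead of letting it grow with every colliding index; per the PR description, the kernel gets the same effect in parallel by giving each thread its own fp32 buffer.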
```diff
@@ -277,11 +277,16 @@ class TestScatterGather(TestCase):
         self.assertEqual(input, expected_result)

     @onlyCPU
-    @dtypes(torch.float32, torch.float64, torch.bfloat16)
+    @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
     def test_scatter_expanded_index(self, device, dtype):
-        def helper(input_size, idx_size):
+        def helper(input_size, idx_size, atol=1e-5, rtol=0.016):
+            is_reduced_type = dtype in [torch.bfloat16, torch.float16]
+            if is_reduced_type:
+                atol = 1e-2
+                rtol = 1e-2
             input = torch.randn(input_size, device=device).to(dtype=dtype)
-            input2 = input.clone()
+            input2 = input.clone().to(torch.float32) if is_reduced_type else input.clone()
             input3 = input.clone()

             shape = [1] * len(input_size)
             shape[0] = idx_size
@@ -300,17 +305,27 @@ class TestScatterGather(TestCase):
             idx = idx.expand(expanded_shape)
             idx2 = idx.contiguous()
             src = torch.randn(expanded_shape, device=device).to(dtype=dtype)
+            src2 = src.clone().to(torch.float32) if is_reduced_type else src.clone()

             out = input.scatter_add(0, idx, src)
-            out2 = input2.scatter_add(0, idx2, src)
+            out2 = input2.scatter_add(0, idx2, src2)

-            self.assertEqual(out, out2)
+            if torch.has_openmp:
+                self.assertEqual(out, out2.to(dtype) if is_reduced_type else out2, atol=atol, rtol=rtol)
+            else:
+                out3 = input3.scatter_add(0, idx2, src)
+                self.assertEqual(out, out3)

             for reduce in ["sum", "prod", "mean", "amax", "amin"]:
                 for include_self in [True, False]:
                     out = input.scatter_reduce(0, idx, src, reduce=reduce, include_self=include_self)
-                    out2 = input2.scatter_reduce(0, idx2, src, reduce=reduce, include_self=include_self)
-                    self.assertEqual(out, out2)
+                    out2 = input2.scatter_reduce(0, idx2, src2, reduce=reduce, include_self=include_self)
+                    if torch.has_openmp:
+                        self.assertEqual(out, out2.to(dtype) if is_reduced_type else out2,
+                                         atol=atol, rtol=rtol)
+                    else:
+                        out3 = input3.scatter_reduce(0, idx2, src, reduce=reduce, include_self=include_self)
+                        self.assertEqual(out, out3)

             helper([50, 17], 100)
             helper([50, 1], 100)
```
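The comparison strategy in the updated test can be reproduced standalone. A hedged sketch, using only public PyTorch APIs (`torch.testing.assert_close` stands in for the test harness's `assertEqual`), with the same shapes and tolerances the test uses:

```python
import torch

dtype = torch.bfloat16
input_ = torch.randn(50, 17).to(dtype)
# Expanded index: a (100, 1) index broadcast across the feature dim,
# the layout that test_scatter_expanded_index exercises.
idx = torch.randint(0, 50, (100, 1)).expand(100, 17)
src = torch.randn(100, 17).to(dtype)

out = input_.scatter_add(0, idx, src)

# fp32 reference: the same computation at full precision.
ref = input_.to(torch.float32).scatter_add(0, idx, src.to(torch.float32))

# Same tolerances the updated test applies for reduced-precision types.
torch.testing.assert_close(out, ref.to(dtype), atol=1e-2, rtol=1e-2)
```

Comparing against an fp32 reference, rather than bf16 against bf16, is what makes the loosened atol/rtol of 1e-2 meaningful: the reference itself carries no reduced-precision accumulation error.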