Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
optimize gather performance for gnn usage on CPU (#87586)
In the classic PyG message-passing use case, `gather` receives an `index` tensor in a broadcasted shape, e.g. shape `[5000, 128]` with stride `[1, 0]`, which indicates that the gather selects whole rows of the self tensor. The current implementation tries to parallelize on the inner dimension, which performs poorly on CPU and cannot be vectorized. This PR addresses that use case and optimizes it in a manner similar to `index_select`: parallelize on the outer dimension of `index` and do a vectorized copy on the inner dimension.

Performance benchmarking on a single Xeon Ice Lake socket running `GCN`: `gather` time dropped from `150.787ms` to `10.926ms`. After this optimization, `gather` is no longer the major bottleneck for training GNN models when `EdgeIndex` is in COO format. For more details, see https://github.com/pyg-team/pytorch_geometric/issues/4891#issuecomment-1288423705

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87586
Approved by: https://github.com/rusty1s, https://github.com/malfet
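The broadcasted-index pattern described above can be sketched as follows (a minimal illustration; the shapes are examples, not taken from the benchmark). The key point is that an index column expanded across the feature dimension has stride `(1, 0)`, so a gather with it is equivalent to a row-wise `index_select`, which is what the fast path exploits:

```python
import torch

# A PyG-style broadcasted index: one index per output row, expanded
# (not copied) across the 128 feature columns, giving stride (1, 0).
src = torch.randn(1000, 128)
idx = torch.randint(0, 1000, (5000, 1)).expand(5000, 128)

out = torch.gather(src, 0, idx)

# Because the index is constant along dim 1, every output row is a whole
# row of `src`, so the gather matches a row-wise index_select.
ref = torch.index_select(src, 0, idx[:, 0])
assert torch.equal(out, ref)
```

Before this PR, the gather above parallelized over the 128-wide inner dimension; the optimized path instead parallelizes over the 5000 output rows and vectorizes the row copies.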
Committed by: PyTorch MergeBot
Parent: f8026413f5
Commit: dc6916b341
@@ -305,6 +305,34 @@ class TestScatterGather(TestCase):
         helper([50, 8, 7], 100)
         helper([50, 3, 4, 5], 100)
 
+    @onlyCPU
+    @dtypes(torch.float32, torch.float64, torch.bfloat16)
+    def test_gather_expanded_index(self, device, dtype):
+        def helper(input_size, idx_size):
+            input = torch.randn(input_size, device=device).to(dtype=dtype)
+            input2 = input.clone()
+
+            shape = [1] * len(input_size)
+            shape[0] = idx_size
+            dim_size = input_size[0]
+            idx = torch.randint(0, dim_size, shape)
+
+            # Test the fast path on gather when index is expanded
+            expanded_shape = input_size
+            expanded_shape[0] = idx_size
+            idx = idx.expand(expanded_shape)
+            idx2 = idx.contiguous()
+
+            out = torch.gather(input, 0, idx)
+            out2 = torch.gather(input2, 0, idx2)
+
+            self.assertEqual(out, out2)
+
+        helper([50, 17], 100)
+        helper([50, 1], 100)
+        helper([50, 8, 7], 100)
+        helper([50, 3, 4, 5], 100)
+
 
 # Generic Device Test Framework instantation, see
 # https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests
 # for details.