Change the loop unrolling strategy. Previously, the script unrolled the inner loop over block_size only when the block size was a multiple of the vector length. This version instead unrolls the outer loop, which reduces the number of loads and stores when accumulating into the output array and improves performance in the cases where the block size is not a multiple of the vector length.
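For illustration, a minimal C++ sketch of the two strategies for a sum-mode bag is shown below. The function names, the fp32 element type, and the fixed chunk width of 8 are stand-ins for exposition only, not the actual generated kernel; the point is that the outer-unrolled version keeps its accumulators in registers and writes each output element once per column chunk, and its scalar tail still works when block_size is not a multiple of the vector length.
```cpp
#include <cstdint>

// Inner-loop style: the output row is re-loaded and re-stored for every index,
// and a vectorized fast path would only apply when block_size divides evenly
// by the vector length.
void bag_sum_inner(const float* weight, const int64_t* indices, int64_t n_indices,
                   int64_t block_size, float* out) {
  for (int64_t j = 0; j < block_size; ++j) out[j] = 0.f;
  for (int64_t i = 0; i < n_indices; ++i) {
    const float* row = weight + indices[i] * block_size;
    for (int64_t j = 0; j < block_size; ++j)
      out[j] += row[j];  // out[j] is loaded and stored once per index
  }
}

// Outer-loop style: for each chunk of output columns, iterate over all indices
// while the partial sums stay in local accumulators (registers), then store once.
void bag_sum_outer_unrolled(const float* weight, const int64_t* indices, int64_t n_indices,
                            int64_t block_size, float* out) {
  constexpr int64_t kChunk = 8;  // stand-in for the SIMD vector length
  int64_t j = 0;
  for (; j + kChunk <= block_size; j += kChunk) {
    float acc[kChunk] = {0.f};   // accumulators live in registers across the index loop
    for (int64_t i = 0; i < n_indices; ++i) {
      const float* row = weight + indices[i] * block_size + j;
      for (int64_t k = 0; k < kChunk; ++k) acc[k] += row[k];
    }
    for (int64_t k = 0; k < kChunk; ++k) out[j + k] = acc[k];  // one store per output element
  }
  for (; j < block_size; ++j) {  // scalar tail when block_size is not a multiple of kChunk
    float acc = 0.f;
    for (int64_t i = 0; i < n_indices; ++i)
      acc += weight[indices[i] * block_size + j];
    out[j] = acc;
  }
}
```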
Benchmarking script:
```python
# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate <open-source-office@arm.com>
# SPDX-License-Identifier: BSD-3-Clause
import torch
import torch.nn as nn
import numpy as np
import time
import sys
np.random.seed(0)
torch.manual_seed(0)
num_embeddings = 400000
embedding_dim = int(sys.argv[1])
multi_hot = 100
batch_size = 400
nrun = 1000
class SimpleEmbeddingBagModel(nn.Module):
    def __init__(self, num_embeddings, embedding_dim):
        super(SimpleEmbeddingBagModel, self).__init__()
        weights = torch.from_numpy((np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype(np.float32)).to(torch.float16)
        # Defining the EmbeddingBag layer
        self.embedding_bag = torch.nn.EmbeddingBag(num_embeddings, embedding_dim, _weight=weights,
                                                   mode='sum', include_last_offset=True, dtype=torch.float32)

    def forward(self, input, offsets):
        # Forward pass through the EmbeddingBag layer
        result32 = self.embedding_bag(input, offsets, per_sample_weights=None)
        return result32
# Instantiate the model
model = SimpleEmbeddingBagModel(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
model.eval()
# Example input
input_tensor = torch.randint(0, num_embeddings, (batch_size * multi_hot,), dtype=torch.long)
offsets = torch.tensor(range(0, batch_size * multi_hot + 1, multi_hot))
with torch.no_grad():
    # warm up
    output32 = model(input_tensor, offsets)

    ti = time.time_ns()
    for i in range(nrun):
        _ = model(input_tensor, offsets)
    tf = time.time_ns()
print("{:3d} {:.3E}".format(embedding_dim, (tf-ti)/nrun/1.e6))
```
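The script takes the embedding dimension as its only command-line argument (e.g. `python bench.py 128`, for whatever filename the script is saved under) and prints that dimension followed by the average forward-pass latency in milliseconds over the 1000 timed runs.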
Speedup on Neoverse V1 with 1 thread.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150176
Approved by: https://github.com/digantdesai, https://github.com/malfet