mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Change loop unrolling strategy.

Previously, the script unrolled the inner loop over block_size only when the block size was a multiple of the vector length. This version instead unrolls the outer loop, which reduces the number of loads/stores needed to accumulate into the output array and improves performance when the block size is not a multiple of the vector length.

Benchmarking script:

```python
# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate <open-source-office@arm.com>
# SPDX-License-Identifier: BSD-3-Clause
import torch
import torch.nn as nn
import numpy as np
import time
import sys

np.random.seed(0)
torch.manual_seed(0)

num_embeddings = 400000
embedding_dim = int(sys.argv[1])
multi_hot = 100
batch_size = 400
nrun = 1000

class SimpleEmbeddingBagModel(nn.Module):
    def __init__(self, num_embeddings, embedding_dim):
        super(SimpleEmbeddingBagModel, self).__init__()

        weights = torch.from_numpy((np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype(np.float32)).to(torch.float16)

        # Defining the EmbeddingBag layer
        self.embedding_bag = torch.nn.EmbeddingBag(num_embeddings, embedding_dim, _weight=weights, mode='sum', include_last_offset=True, dtype=torch.float32)

    def forward(self, input, offsets):
        # Forward pass through the EmbeddingBag layer
        result32 = self.embedding_bag(input, offsets, per_sample_weights=None)
        return result32

# Instantiate the model
model = SimpleEmbeddingBagModel(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
model.eval()

# Example input
input_tensor = torch.randint(0, num_embeddings, (batch_size * multi_hot,), dtype=torch.long)

offsets = torch.tensor(range(0, batch_size * multi_hot + 1, multi_hot))

with torch.no_grad():
    # warm up
    output32 = model(input_tensor, offsets)

    ti = time.time_ns()
    for i in range(nrun):
        _ = model(input_tensor, offsets)
    tf = time.time_ns()
    print("{:3d} {:.3E}".format(embedding_dim, (tf-ti)/nrun/1.e6))
```

Speedup on NEOVERSEV1 with 1 thread: (the accompanying figure appears to have been dropped from this text; see the pull request linked below.)

Pull Request resolved: 
https://github.com/pytorch/pytorch/pull/150176 Approved by: https://github.com/digantdesai, https://github.com/malfet