Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Fix performance regression when indexing by Numpy arrays (#163280)
Benchmark script:

```
import time
import numpy as np
import torch

def main() -> None:
    for i in range(10):
        block_indices = np.arange(16384, dtype=np.int32)
        block_indices = block_indices.reshape(-1).clip(max=255)
        batch_indices = np.zeros(16384, dtype=np.int64)
        virtual_batches = 32
        block_table = torch.randn(32, 256)
        start = time.perf_counter()
        block_table[batch_indices, block_indices].view(virtual_batches, -1)
        end = time.perf_counter()
        time_elapsed_ms = (end - start) * 1000
        print(f"Function execution time: {time_elapsed_ms:.1f}ms")

if __name__ == "__main__":
    main()
```

Before:

```
(a) [ezyang@devvm006.dkl0 ~/local/b/pytorch] python ben.py
Function execution time: 28.5ms
Function execution time: 12.9ms
Function execution time: 12.6ms
Function execution time: 13.5ms
Function execution time: 12.0ms
Function execution time: 13.4ms
Function execution time: 12.9ms
Function execution time: 12.9ms
Function execution time: 13.1ms
Function execution time: 13.0ms
```

After:

```
Function execution time: 17.8ms
Function execution time: 2.5ms
Function execution time: 1.3ms
Function execution time: 2.5ms
Function execution time: 2.3ms
Function execution time: 1.3ms
Function execution time: 2.4ms
Function execution time: 2.5ms
Function execution time: 2.5ms
Function execution time: 2.4ms
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163280
Approved by: https://github.com/SherlockNoMad, https://github.com/cyyever
Committed by: PyTorch MergeBot
Parent: 3016616ccb
Commit: c91f59b1a0
@@ -109,7 +109,9 @@ static int64_t count_specified_dimensions(PyObject* index) {
         }
       } else {
         // Check sequences for __torch_function__ (top-level only)
-        if (PySequence_Check(obj)) {
+        // NB: do NOT use PySequence_Check, that will grab things like Numpy
+        // arrays
+        if (PyTuple_Check(obj) || PyList_Check(obj)) {
           if (sequence_has_torch_function(obj)) {
             return -1; // Signal torch function handling needed
           }
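Reviewer note: the hunk above narrows the check from "anything that satisfies the sequence protocol" to "tuple or list only". The following is a rough Python sketch of the difference between the two conditions, not the PyTorch implementation. It assumes CPython (it calls `PySequence_Check` through `ctypes.pythonapi`), and `ArrayLike`, `old_check`, and `new_check` are hypothetical names introduced only for illustration. `ArrayLike` stands in for an array type such as a NumPy ndarray, so the sketch does not need NumPy installed.

```python
import ctypes

# Assumption: CPython, where the C API is reachable through ctypes.pythonapi.
_PySequence_Check = ctypes.pythonapi.PySequence_Check
_PySequence_Check.argtypes = [ctypes.py_object]
_PySequence_Check.restype = ctypes.c_int


class ArrayLike:
    """Hypothetical stand-in for an array type (e.g. a NumPy ndarray)."""

    def __init__(self, data):
        self._data = list(data)

    def __len__(self):
        return len(self._data)

    def __getitem__(self, i):
        return self._data[i]


def old_check(obj) -> bool:
    """Mirrors the removed condition: PySequence_Check(obj)."""
    return bool(_PySequence_Check(obj))


def new_check(obj) -> bool:
    """Mirrors the added condition: PyTuple_Check(obj) || PyList_Check(obj)."""
    return isinstance(obj, (tuple, list))


if __name__ == "__main__":
    for value in ([1, 2], (1, 2), ArrayLike(range(4))):
        print(type(value).__name__, old_check(value), new_check(value))
```

Under these assumptions, the array-like object passes the old condition but not the new one, while tuples and lists pass both. This matches the comment in the diff that `PySequence_Check` "will grab things like Numpy arrays".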