## Motivation

Many PRs that optimize samplers (e.g. https://github.com/pytorch/pytorch/pull/147706, https://github.com/pytorch/pytorch/pull/137423) rely on an ad hoc script for benchmarking samplers, and the script and its outputs are copied into each PR. We want to begin centralizing benchmarks for torch.utils.data components.

## What?

* This PR adds a new `data` sub-folder under `benchmarks`, intended to hold benchmarking scripts for torch.utils.data components such as the dataloader and samplers.
* Specifically, this PR includes a simple script to time samplers, the one that is often copy-pasted into PRs optimizing samplers. Keeping it in a centralized location should prevent that and allow a common standard.

## Output

The speedup column is computed as `(original - new) / original`, so negative values mean the new implementation is slower than the baseline.

```
Benchmark Results:
+--------------+-------------+----------------+-----------+-----------+
| Batch Size   | Drop Last   | Original (s)   | New (s)   | Speedup   |
+==============+=============+================+===========+===========+
| 4            | True        | 0.004          | 0.0088    | -119.62%  |
+--------------+-------------+----------------+-----------+-----------+
| 4            | False       | 0.0083         | 0.009     | -9.23%    |
+--------------+-------------+----------------+-----------+-----------+
| 8            | True        | 0.003          | 0.0074    | -147.64%  |
+--------------+-------------+----------------+-----------+-----------+
| 8            | False       | 0.0054         | 0.0075    | -38.72%   |
+--------------+-------------+----------------+-----------+-----------+
| 64           | True        | 0.0021         | 0.0056    | -161.92%  |
+--------------+-------------+----------------+-----------+-----------+
| 64           | False       | 0.0029         | 0.0055    | -92.50%   |
+--------------+-------------+----------------+-----------+-----------+
| 640          | True        | 0.002          | 0.0055    | -168.75%  |
+--------------+-------------+----------------+-----------+-----------+
| 640          | False       | 0.0024         | 0.0062    | -161.35%  |
+--------------+-------------+----------------+-----------+-----------+
| 6400         | True        | 0.0021         | 0.0055    | -160.13%  |
+--------------+-------------+----------------+-----------+-----------+
| 6400         | False       | 0.0021         | 0.0068    | -215.46%  |
+--------------+-------------+----------------+-----------+-----------+
| 64000        | True        | 0.0042         | 0.0065    | -55.29%   |
+--------------+-------------+----------------+-----------+-----------+
| 64000        | False       | 0.0029         | 0.0077    | -169.56%  |
+--------------+-------------+----------------+-----------+-----------+
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156974
Approved by: https://github.com/ramanishsingh
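To benchmark a different pair of sampler implementations with this script, only the two assignments near the top of `main()` need to change. A minimal sketch, where `MyBatchSampler` is a hypothetical stand-in for whatever implementation a PR proposes:

```python
# Inside main(): point the two names at the implementations to compare.
baselineSampler = BatchSampler  # reference implementation from torch.utils.data
testSampler = MyBatchSampler    # hypothetical candidate under test
```

The script itself follows.
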
#!/usr/bin/env python3

import time
from collections.abc import Iterable, Iterator
from typing import Union

import numpy as np
from tabulate import tabulate

from torch.utils.data import BatchSampler, Sampler, SequentialSampler


class NewBatchSampler(Sampler[list[int]]):
    """Alternative implementation of BatchSampler for benchmarking purposes."""

    def __init__(
        self,
        sampler: Union[Sampler[int], Iterable[int]],
        batch_size: int,
        drop_last: bool,
    ) -> None:
        if (
            not isinstance(batch_size, int)
            or isinstance(batch_size, bool)
            or batch_size <= 0
        ):
            raise ValueError(
                f"batch_size should be a positive integer value, but got batch_size={batch_size}"
            )
        if not isinstance(drop_last, bool):
            raise ValueError(
                f"drop_last should be a boolean value, but got drop_last={drop_last}"
            )
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[list[int]]:
        if self.drop_last:
            # Pull exactly batch_size indices per batch; a StopIteration
            # raised mid-batch discards the final partial batch.
            sampler_iter = iter(self.sampler)
            while True:
                try:
                    batch = [next(sampler_iter) for _ in range(self.batch_size)]
                    yield batch
                except StopIteration:
                    break
        else:
            # Fill a preallocated list and yield a trimmed copy of the final
            # partial batch, if any.
            batch = [0] * self.batch_size
            idx_in_batch = 0
            for idx in self.sampler:
                batch[idx_in_batch] = idx
                idx_in_batch += 1
                if idx_in_batch == self.batch_size:
                    yield batch
                    idx_in_batch = 0
                    batch = [0] * self.batch_size
            if idx_in_batch > 0:
                yield batch[:idx_in_batch]

    def __len__(self) -> int:
        # Can only be called if self.sampler has __len__ implemented
        if self.drop_last:
            return len(self.sampler) // self.batch_size  # type: ignore[arg-type]
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size  # type: ignore[arg-type]


def main():
    """Run benchmark with specified parameters."""
    DATA_SIZE = 99999
    AVG_TIMES = 10
    BATCH_SIZES = [4, 8, 64, 640, 6400, 64000]
    DROP_LAST_OPTIONS = [True, False]

    results = []

    # Set up the samplers to compare here; ensure the right args are passed in
    baselineSampler = BatchSampler
    testSampler = NewBatchSampler

    for batch_size in BATCH_SIZES:
        for drop_last in DROP_LAST_OPTIONS:
            print(f"Benchmarking with batch_size={batch_size}, drop_last={drop_last}")

            # Benchmark baselineSampler
            original_times = []
            for _ in range(AVG_TIMES):
                start = time.perf_counter()
                for _ in baselineSampler(
                    sampler=SequentialSampler(range(DATA_SIZE)),
                    batch_size=batch_size,
                    drop_last=drop_last,
                ):
                    pass
                end = time.perf_counter()
                original_times.append(end - start)
                time.sleep(0.1)  # Small delay to reduce system load

            original_avg = float(np.mean(original_times))

            # Benchmark testSampler
            new_times = []
            for _ in range(AVG_TIMES):
                start = time.perf_counter()
                for _ in testSampler(
                    sampler=SequentialSampler(range(DATA_SIZE)),
                    batch_size=batch_size,
                    drop_last=drop_last,
                ):
                    pass
                end = time.perf_counter()
                new_times.append(end - start)
                time.sleep(0.1)  # Small delay to reduce system load

            new_avg = float(np.mean(new_times))

            # Calculate speedup: positive means testSampler is faster than
            # the baseline, negative means it is slower
            if original_avg > 0 and new_avg > 0:
                speedup = (original_avg - new_avg) / original_avg * 100
                speedup_str = f"{speedup:.2f}%"
            else:
                speedup_str = "N/A"

            print(f"Speedup: {speedup_str}\n")

            results.append(
                [
                    batch_size,
                    drop_last,
                    f"{original_avg:.4f}",
                    f"{new_avg:.4f}",
                    speedup_str,
                ]
            )

    # Print results in a table
    headers = ["Batch Size", "Drop Last", "Original (s)", "New (s)", "Speedup"]
    print("\nBenchmark Results:")
    print(tabulate(results, headers=headers, tablefmt="grid"))


if __name__ == "__main__":
    main()