mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter. You can review these PRs via `git diff --ignore-all-space --ignore-blank-lines HEAD~1`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129754
Approved by: https://github.com/ezyang
32 lines
778 B
Python
32 lines
778 B
Python
import argparse
|
|
|
|
from common import SubTensor, SubWithTorchFunction, WithTorchFunction # noqa: F401
|
|
|
|
import torch
|
|
|
|
|
|
# ``torch.tensor`` bound under a class-like name so that "Tensor" can be passed
# on the command line as the baseline (plain tensors, no subclass/override).
Tensor = torch.tensor

# Default number of ``torch.add`` calls per run; overridable with --nreps / -n.
NUM_REPEATS = 1000000

# Names accepted for the ``tensor_class`` argument. Each is looked up in this
# module's globals: ``Tensor`` is defined above, the others are imported from
# ``common``.
TENSOR_CLASS_NAMES = (
    "Tensor",
    "SubTensor",
    "WithTorchFunction",
    "SubWithTorchFunction",
)


def bench_add(tensor_class, nreps):
    """Call ``torch.add`` on two one-element *tensor_class* values *nreps* times.

    The operands are built once, before the loop, so the loop contains only
    the ``torch.add`` calls being measured. The results are discarded.

    Args:
        tensor_class: Callable that builds an operand from a list of floats
            (e.g. ``torch.tensor`` or a tensor-like class).
        nreps: Number of ``torch.add`` calls to make.
    """
    t1 = tensor_class([1.0])
    t2 = tensor_class([2.0])
    for _ in range(nreps):
        torch.add(t1, t2)


def main(argv=None):
    """Parse command-line arguments and run the ``torch.add`` benchmark loop.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Exits with status 2 (via ``argparse``) if ``tensor_class`` is not one of
    ``TENSOR_CLASS_NAMES`` or if ``--nreps`` is negative.
    """
    parser = argparse.ArgumentParser(
        description="Run the torch.add for a given class a given number of times."
    )
    parser.add_argument(
        "tensor_class",
        metavar="TensorClass",
        type=str,
        # Restrict to the known classes so a typo produces a clear usage error
        # instead of a KeyError (or an unrelated global) from the lookup below.
        choices=TENSOR_CLASS_NAMES,
        help="The class to benchmark.",
    )
    parser.add_argument(
        "--nreps", "-n", type=int, default=NUM_REPEATS, help="The number of repeats."
    )
    args = parser.parse_args(argv)
    if args.nreps < 0:
        # range() of a negative count would silently benchmark nothing.
        parser.error("--nreps must be non-negative")

    bench_add(globals()[args.tensor_class], args.nreps)


if __name__ == "__main__":
    main()