Files
pytorch/benchmarks/instruction_counts/definitions/setup.py
2023-07-26 01:18:48 +00:00

48 lines
1.6 KiB
Python

"""Define some common setup blocks which benchmarks can reuse."""
import enum
from core.api import GroupedSetup
from core.utils import parse_stmts
_TRIVIAL_2D = GroupedSetup(r"x = torch.ones((4, 4))", r"auto x = torch::ones({4, 4});")
_TRIVIAL_3D = GroupedSetup(
r"x = torch.ones((4, 4, 4))", r"auto x = torch::ones({4, 4, 4});"
)
_TRIVIAL_4D = GroupedSetup(
r"x = torch.ones((4, 4, 4, 4))", r"auto x = torch::ones({4, 4, 4, 4});"
)
_TRAINING = GroupedSetup(
*parse_stmts(
r"""
Python | C++
---------------------------------------- | ----------------------------------------
# Inputs | // Inputs
x = torch.ones((1,)) | auto x = torch::ones({1});
y = torch.ones((1,)) | auto y = torch::ones({1});
|
# Weights | // Weights
w0 = torch.ones( | auto w0 = torch::ones({1});
(1,), requires_grad=True) | w0.set_requires_grad(true);
w1 = torch.ones( | auto w1 = torch::ones({1});
(1,), requires_grad=True) | w1.set_requires_grad(true);
w2 = torch.ones( | auto w2 = torch::ones({2});
(2,), requires_grad=True) | w2.set_requires_grad(true);
"""
)
)
class Setup(enum.Enum):
TRIVIAL_2D = _TRIVIAL_2D
TRIVIAL_3D = _TRIVIAL_3D
TRIVIAL_4D = _TRIVIAL_4D
TRAINING = _TRAINING