pytorch/benchmarks/framework_overhead_benchmark/SimpleAddModule.py
Xiang Gao 20ac736200 Remove py2 compatible future imports (#44735)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44735

Reviewed By: mruberry

Differential Revision: D23731306

Pulled By: ezyang

fbshipit-source-id: 0ba009a99e475ddbe22981be8ac636f8a1c8b02f
2020-09-16 12:55:57 -07:00

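import torch
from utils import NUM_LOOP_ITERS


def add_tensors_loop(x, y):
    # One add of x and y, followed by NUM_LOOP_ITERS further adds of x,
    # so the result equals y + (NUM_LOOP_ITERS + 1) * x.
    z = torch.add(x, y)
    for _ in range(NUM_LOOP_ITERS):
        z = torch.add(z, x)
    return z


class SimpleAddModule(torch.nn.Module):
    # Thin nn.Module wrapper around an arbitrary binary add op
    # (e.g. torch.add or add_tensors_loop) so it can be called as a module.
    def __init__(self, add_op):
        super().__init__()
        self.add_op = add_op

    def forward(self, x, y):
        return self.add_op(x, y)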
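A minimal usage sketch, assuming PyTorch is installed. Because `utils` is not shown here, the example copies the two definitions and sets a hypothetical `NUM_LOOP_ITERS = 100`; the timing helper `time_forward` is also hypothetical, not part of the benchmark. It wraps both `torch.add` and `add_tensors_loop` in `SimpleAddModule`, checks that the looped op returns `y + (NUM_LOOP_ITERS + 1) * x`, and times forward calls on one-element tensors, where the arithmetic itself is negligible and per-call cost is mostly framework overhead.

```python
import time

import torch

# Hypothetical stand-in for utils.NUM_LOOP_ITERS; the real value lives in the
# benchmark's utils module and is not shown in this file.
NUM_LOOP_ITERS = 100


def add_tensors_loop(x, y):
    # Same structure as the file's function: one add, then NUM_LOOP_ITERS more.
    z = torch.add(x, y)
    for _ in range(NUM_LOOP_ITERS):
        z = torch.add(z, x)
    return z


class SimpleAddModule(torch.nn.Module):
    def __init__(self, add_op):
        super().__init__()
        self.add_op = add_op

    def forward(self, x, y):
        return self.add_op(x, y)


def time_forward(module, x, y, iters=50):
    """Hypothetical helper: mean wall-clock seconds per forward call."""
    with torch.no_grad():
        module(x, y)  # warm-up call, excluded from timing
        start = time.perf_counter()
        for _ in range(iters):
            module(x, y)
        return (time.perf_counter() - start) / iters


# One-element tensors keep the math trivial, so timing mostly reflects overhead.
x = torch.ones(1)
y = torch.ones(1)

single = SimpleAddModule(torch.add)
looped = SimpleAddModule(add_tensors_loop)

print(single(x, y))  # tensor([2.])
print(looped(x, y))  # y + (NUM_LOOP_ITERS + 1) * x -> tensor([102.])
print(f"looped forward: {time_forward(looped, x, y) * 1e6:.1f} us/call")
```

Passing the op into the constructor lets the same module wrapper measure a single dispatch (`torch.add`) and a long chain of dispatches (`add_tensors_loop`) without changing any other code.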