Minor fixes in fastrnns benchmarks

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18613

Reviewed By: wanchaol

Differential Revision: D14681838

fbshipit-source-id: 60bd5c9b09398c74335f003cd21ea32dd1c45876
This commit is contained in:
Junjie Bai
2019-03-29 01:16:52 -07:00
committed by Facebook Github Bot
parent d859031ebf
commit e22a2b9015
2 changed files with 4 additions and 5 deletions

View File

@ -1,6 +1,5 @@
from .cells import *
from .factory import *
from .test import *
# (output, next_state) = cell(input, state)
seqLength = 100

View File

@ -6,7 +6,7 @@ import time
import torch
import datetime
from .runner import get_rnn_runners
from .runner import get_nn_runners
PY3 = sys.version_info >= (3, 0)
@ -48,7 +48,7 @@ def profile(rnns, sleep_between_seconds=1, nloops=5,
params = dict(seqLength=seqLength, numLayers=numLayers,
inputSize=inputSize, hiddenSize=hiddenSize,
miniBatch=miniBatch, device=device, seed=seed)
for name, creator, context in get_rnn_runners(*rnns):
for name, creator, context in get_nn_runners(*rnns):
with context():
run_rnn(name, creator, nloops, **params)
time.sleep(sleep_between_seconds)
@ -94,11 +94,11 @@ def nvprof(cmd, outpath):
def full_profile(rnns, **args):
args['internal_run'] = True
profile_args = []
for k, v in args.items():
profile_args.append('--{}={}'.format(k, v))
profile_args.append('--rnns {}'.format(' '.join(rnns)))
profile_args.append('--internal_run')
outpath = nvprof_output_filename(rnns, **args)
@ -125,7 +125,7 @@ if __name__ == '__main__':
# if internal_run, we actually run the rnns.
# if not internal_run, we shell out to nvprof with internal_run=T
parser.add_argument('--internal_run', default=False, type=bool,
parser.add_argument('--internal_run', default=False, action='store_true',
help='Don\'t use this')
args = parser.parse_args()
if args.rnns is None: