mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Minor fixes in fastrnns benchmarks
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18613 Reviewed By: wanchaol Differential Revision: D14681838 fbshipit-source-id: 60bd5c9b09398c74335f003cd21ea32dd1c45876
This commit is contained in:
committed by
Facebook Github Bot
parent
d859031ebf
commit
e22a2b9015
@ -1,6 +1,5 @@
|
||||
from .cells import *
|
||||
from .factory import *
|
||||
from .test import *
|
||||
|
||||
# (output, next_state) = cell(input, state)
|
||||
seqLength = 100
|
||||
|
@ -6,7 +6,7 @@ import time
|
||||
import torch
|
||||
import datetime
|
||||
|
||||
from .runner import get_rnn_runners
|
||||
from .runner import get_nn_runners
|
||||
|
||||
PY3 = sys.version_info >= (3, 0)
|
||||
|
||||
@ -48,7 +48,7 @@ def profile(rnns, sleep_between_seconds=1, nloops=5,
|
||||
params = dict(seqLength=seqLength, numLayers=numLayers,
|
||||
inputSize=inputSize, hiddenSize=hiddenSize,
|
||||
miniBatch=miniBatch, device=device, seed=seed)
|
||||
for name, creator, context in get_rnn_runners(*rnns):
|
||||
for name, creator, context in get_nn_runners(*rnns):
|
||||
with context():
|
||||
run_rnn(name, creator, nloops, **params)
|
||||
time.sleep(sleep_between_seconds)
|
||||
@ -94,11 +94,11 @@ def nvprof(cmd, outpath):
|
||||
|
||||
|
||||
def full_profile(rnns, **args):
|
||||
args['internal_run'] = True
|
||||
profile_args = []
|
||||
for k, v in args.items():
|
||||
profile_args.append('--{}={}'.format(k, v))
|
||||
profile_args.append('--rnns {}'.format(' '.join(rnns)))
|
||||
profile_args.append('--internal_run')
|
||||
|
||||
outpath = nvprof_output_filename(rnns, **args)
|
||||
|
||||
@ -125,7 +125,7 @@ if __name__ == '__main__':
|
||||
|
||||
# if internal_run, we actually run the rnns.
|
||||
# if not internal_run, we shell out to nvprof with internal_run=T
|
||||
parser.add_argument('--internal_run', default=False, type=bool,
|
||||
parser.add_argument('--internal_run', default=False, action='store_true',
|
||||
help='Don\'t use this')
|
||||
args = parser.parse_args()
|
||||
if args.rnns is None:
|
||||
|
Reference in New Issue
Block a user