mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Benchmarks: add scripts for FastRNNs results comparison. (#44134)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44134 Test Plan: Imported from OSS Reviewed By: bertmaher Differential Revision: D23505810 Pulled By: ZolotukhinM fbshipit-source-id: d0b3d70d4c2a44a8c3773631d09a25a98ec59370
This commit is contained in:
committed by
Facebook GitHub Bot
parent
3806c939bd
commit
d0421ff1cc
55
benchmarks/compare-fastrnn-results.py
Normal file
55
benchmarks/compare-fastrnn-results.py
Normal file
@ -0,0 +1,55 @@
|
||||
import argparse
|
||||
import json
|
||||
from collections import namedtuple
|
||||
|
||||
Result = namedtuple("Result", ["name", "base_time", "diff_time"])
|
||||
|
||||
def construct_name(fwd_bwd, test_name):
|
||||
bwd = 'backward' in fwd_bwd
|
||||
suite_name = fwd_bwd.replace('-backward', '')
|
||||
return '{suite}[{test}]:{fwd_bwd}'.format(suite=suite_name, test=test_name, fwd_bwd='bwd' if bwd else 'fwd')
|
||||
|
||||
def get_times(json_data):
|
||||
r = {}
|
||||
for fwd_bwd in json_data:
|
||||
for test_name in json_data[fwd_bwd]:
|
||||
name = construct_name(fwd_bwd, test_name)
|
||||
r[name] = json_data[fwd_bwd][test_name]
|
||||
return r
|
||||
|
||||
parser = argparse.ArgumentParser("compare two pytest jsons")
|
||||
parser.add_argument('base', help="base json file")
|
||||
parser.add_argument('diff', help='diff json file')
|
||||
parser.add_argument('--format', default='md', type=str, help='output format (csv, md, json, table)')
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.base, "r") as base:
|
||||
base_times = get_times(json.load(base))
|
||||
with open(args.diff, "r") as diff:
|
||||
diff_times = get_times(json.load(diff))
|
||||
|
||||
all_keys = set(base_times.keys()).union(diff_times.keys())
|
||||
results = [
|
||||
Result(name, base_times.get(name, float("nan")), diff_times.get(name, float("nan")))
|
||||
for name in sorted(all_keys)
|
||||
]
|
||||
|
||||
header_fmt = {'table' : '{:48s} {:>13s} {:>15s} {:>10s}',
|
||||
'md' : '| {:48s} | {:>13s} | {:>15s} | {:>10s} |',
|
||||
'csv' : '{:s}, {:s}, {:s}, {:s}'}
|
||||
data_fmt = {'table' : '{:48s} {:13.6f} {:15.6f} {:9.1f}%',
|
||||
'md' : '| {:48s} | {:13.6f} | {:15.6f} | {:9.1f}% |',
|
||||
'csv' : '{:s}, {:.6f}, {:.6f}, {:.2f}%'}
|
||||
|
||||
if args.format in ['table', 'md', 'csv']:
|
||||
header_fmt_str = header_fmt[args.format]
|
||||
data_fmt_str = data_fmt[args.format]
|
||||
print(header_fmt_str.format("name", "base time (s)", "diff time (s)", "% change"))
|
||||
if args.format == 'md':
|
||||
print(header_fmt_str.format(":---", "---:", "---:", "---:"))
|
||||
for r in results:
|
||||
print(data_fmt_str.format(r.name, r.base_time, r.diff_time, (r.diff_time / r.base_time - 1.0) * 100.0))
|
||||
elif args.format == 'json':
|
||||
print(json.dumps(results))
|
||||
else:
|
||||
raise ValueError('Unknown output format: ' + args.format)
|
4
benchmarks/compare.sh
Normal file
4
benchmarks/compare.sh
Normal file
@ -0,0 +1,4 @@
|
||||
#!/bin/bash
|
||||
python -m fastrnns.bench --fuser=old --group=rnns --print-json oss > old.json
|
||||
python -m fastrnns.bench --fuser=te --group=rnns --print-json oss > te.json
|
||||
python compare-fastrnn-results.py old.json te.json --format md
|
Reference in New Issue
Block a user