[functorch] Added decomposition testing + display infra (pytorch/functorch#281)

* Added decomposition testing + display infra

* add a couple more decompositions

* changed some stuff

* made some changes

* Added decomposition testing + display infra

* add a couple more decompositions

* fix some decompositions

* changed some stuff

* updated generation

* fix test failures

* removed extraneous files

* fixed test failures

* fixed tests

* updated

* fixed tests again
This commit is contained in:
Horace He
2021-11-29 18:19:36 -05:00
committed by Jon Janzen
parent 0fa9f7af83
commit fe9aac72d1
10 changed files with 698 additions and 2291 deletions

View File

@ -76,94 +76,100 @@ def get_ops_for_key(key):
cleaned_ops.append(i[6:].strip())
return set(cleaned_ops)
batched_registrations = get_ops_for_key('FuncTorchBatched')
all_ops = get_ops_for_key(None)
def gen_data(special_op_lists, analysis_name):
all_ops = get_ops_for_key(None)
composite_ops = get_ops_for_key('CompositeImplicitAutograd')
noncomposite_ops = all_ops - composite_ops
composite_ops = get_ops_for_key('CompositeImplicitAutograd')
ops = yaml.load(open('/home/chilli/fb/pytorch/aten/src/ATen/native/native_functions.yaml', 'r').read(), Loader=yaml.CLoader)
annotated_ops = {a.strip(): b.strip() for a,b in list(csv.reader(open('annotated_ops.txt')))}
from collections import defaultdict
vmap_ops = batched_registrations
noncomposite_ops = all_ops - composite_ops
ops = yaml.load(open('/home/chilli/fb/pytorch/aten/src/ATen/native/native_functions.yaml', 'r').read(), Loader=yaml.CLoader)
annotated_ops = {a.strip(): b.strip() for a,b in list(csv.reader(open('annotated_ops.txt')))}
from collections import defaultdict
uniq_ops = []
uniq_names = set()
overload_types = defaultdict(list)
cnt = 0
for op in ops:
func_str = op['func']
name = func_str[:func_str.index('(')]
if '.' in name:
uniq_name = name[:name.index('.')]
overload_types[name[name.index('.') + 1:]].append(name)
else:
uniq_name = name
op['name'] = uniq_name
full_name = func_str[:func_str.index('(')]
op['full_name'] = full_name
ret_type = func_str[func_str.index('->') + 3:]
op['ret_type'] = ret_type
cnt += 1
if uniq_name in uniq_names:
continue
uniq_names.add(uniq_name)
uniq_ops.append(op)
def annotate_ops(ops, is_unique):
categorization = defaultdict(int)
uniq_ops = []
uniq_names = set()
overload_types = defaultdict(list)
cnt = 0
for op in ops:
old_tcnt = sum(categorization.values())
if op['name'][-1] == '_':
categorization['inplace'] += 1
op['meta'] = 'inplace'
continue
if not is_unique and 'a!' in op['func'].lower():
categorization['out'] += 1
op['meta'] = 'out'
continue
if 'conv' in op['name']:
categorization['conv'] += 1
op['meta'] = 'conv'
continue
if 'pool' in op['name']:
categorization['pool'] += 1
op['meta'] = 'pool'
continue
if 'backward' in op['name']:
categorization['backward'] += 1
op['meta'] = 'backward'
continue
if op['name'][0] == '_' and op['name'][1] != '_':
categorization['private'] += 1
op['meta'] = 'private'
continue
if 'batch_norm' in op['name']:
categorization['batch_norm'] += 1
op['meta'] = 'batch_norm'
continue
if 'Tensor' not in op['func'] or'Tensor' not in op['ret_type']:
categorization['non_tensor'] += 1
op['meta'] = 'non_tensor'
continue
if 'cudnn' in op['name'] or 'mkldnn' in op['name'] or 'miopen' in op['name'] or 'native' in op['name'] or 'thnn' in op['name'] or 'slow' in op['name']:
categorization['backend'] += 1
op['meta'] = 'backend'
continue
if op['name'] in annotated_ops:
categorization['core'] += 1
op['meta'] = 'core ' + annotated_ops[op['name']]
func_str = op['func']
name = func_str[:func_str.index('(')]
if '.' in name:
uniq_name = name[:name.index('.')]
overload_types[name[name.index('.') + 1:]].append(name)
else:
categorization['core'] += 1
op['meta'] = 'core unknown'
return categorization
uniq_name = name
op['name'] = uniq_name
full_name = func_str[:func_str.index('(')]
op['full_name'] = full_name
ret_type = func_str[func_str.index('->') + 3:]
op['ret_type'] = ret_type
cnt += 1
if uniq_name in uniq_names:
continue
uniq_names.add(uniq_name)
uniq_ops.append(op)
# categorization = annotate_ops(uniq_ops, True)
categorization = annotate_ops(ops, False)
def annotate_ops(ops, is_unique):
categorization = defaultdict(int)
for op in ops:
old_tcnt = sum(categorization.values())
if op['name'][-1] == '_':
categorization['inplace'] += 1
op['meta'] = 'inplace'
continue
if not is_unique and 'a!' in op['func'].lower():
categorization['out'] += 1
op['meta'] = 'out'
continue
if 'conv' in op['name']:
categorization['conv'] += 1
op['meta'] = 'conv'
continue
if 'pool' in op['name']:
categorization['pool'] += 1
op['meta'] = 'pool'
continue
if 'backward' in op['name']:
categorization['backward'] += 1
op['meta'] = 'backward'
continue
if op['name'][0] == '_' and op['name'][1] != '_':
categorization['private'] += 1
op['meta'] = 'private'
continue
if 'batch_norm' in op['name']:
categorization['batch_norm'] += 1
op['meta'] = 'batch_norm'
continue
if 'Tensor' not in op['func'] or'Tensor' not in op['ret_type']:
categorization['non_tensor'] += 1
op['meta'] = 'non_tensor'
continue
if 'cudnn' in op['name'] or 'mkldnn' in op['name'] or 'miopen' in op['name'] or 'native' in op['name'] or 'thnn' in op['name'] or 'slow' in op['name']:
categorization['backend'] += 1
op['meta'] = 'backend'
continue
if op['name'] in annotated_ops:
categorization['core'] += 1
op['meta'] = 'core ' + annotated_ops[op['name']]
else:
categorization['core'] += 1
op['meta'] = 'core unknown'
return categorization
for op in ops:
info = [op['full_name'], op['meta'], not (op['full_name'] in noncomposite_ops), op['full_name'] in vmap_ops]
print(','.join([str(i) for i in info]))
# categorization = annotate_ops(uniq_ops, True)
categorization = annotate_ops(ops, False)
with open(f"{analysis_name}", 'w') as f:
for op in ops:
info = [op['full_name'], op['meta'], not (op['full_name'] in noncomposite_ops)] + [op['name'] in op_list for op_list in special_op_lists]
f.write(','.join([str(i) for i in info]) + '\n')
# Generates batching rule data
# gen_data([get_ops_for_key('FuncTorchBatched')], 'vmap')
if True:
with open('run_ops.txt', 'r') as f:
opinfo_ops = [i.strip() for i in f.readlines()]
with open('run_decompositions.txt', 'r') as f:
decomposed_ops = [i.strip() for i in f.readlines()]
gen_data([opinfo_ops, decomposed_ops], 'decompositions')