pytorch/test/cpp_api_parity/parity_table_parser.py
Will Feng (FAIAR) 2fa3c1570d Refactor C++ API parity test mechanism and turn it on in CI again (#35190)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35190

The following are the main changes:
- The main logic of the C++ API parity test mechanism is moved from `test/test_cpp_api_parity.py` to `test/cpp_api_parity/module_impl_check.py` and `test/cpp_api_parity/functional_impl_check.py`, giving a clear separation between module tests and functional tests. The two still share many common utility functions, all of which live in `test/cpp_api_parity/utils.py`.
- Module init tests (i.e. tests of whether a C++ module accepts the same constructor options as the corresponding Python module) are removed and will be added again in the future.
- `cpp_constructor_args` / `cpp_options_args` / `cpp_function_call` are added as appropriate to every test params dict in `torch/testing/_internal/common_nn.py`, to indicate how to run the C++ API parity test for that test params dict.

Test Plan: Imported from OSS

Differential Revision: D20588198

Pulled By: yf225

fbshipit-source-id: 11238c560c8247129584b9b49df73fff40c4d81d
2020-04-03 11:20:36 -07:00


from collections import namedtuple

ParityStatus = namedtuple('ParityStatus', ['has_impl_parity', 'has_doc_parity'])

'''
This function expects the parity tracker Markdown file to have the following format:

```
## package1_name

API | Implementation Parity | Doc Parity
------------- | ------------- | -------------
API_Name|No|No
...

## package2_name

API | Implementation Parity | Doc Parity
------------- | ------------- | -------------
API_Name|No|No
...
```

The returned dict has the following format:

```
Dict[package_name]
    -> Dict[api_name]
        -> ParityStatus
```
'''
def parse_parity_tracker_table(file_path):
    def parse_parity_choice(choice):
        # Map the table cell text to a bool; anything other than "Yes"/"No" is an error.
        if choice in ['Yes', 'No']:
            return choice == 'Yes'
        else:
            raise RuntimeError(
                '{} is not a supported parity choice. The valid choices are "Yes" and "No".'.format(choice))

    parity_tracker_dict = {}

    with open(file_path, 'r') as f:
        all_text = f.read()

    # Each package section starts with a `## package_name` heading.
    # The text before the first heading is ignored.
    packages = all_text.split('##')
    for package in packages[1:]:
        lines = [line.strip() for line in package.split('\n') if line.strip() != '']
        package_name = lines[0]
        if package_name in parity_tracker_dict:
            raise RuntimeError("Duplicated package name `{}` found in {}".format(package_name, file_path))
        else:
            parity_tracker_dict[package_name] = {}
        # lines[1] is the table header and lines[2] is the separator row;
        # the API rows start at lines[3].
        for api_status in lines[3:]:
            api_name, has_impl_parity_str, has_doc_parity_str = [x.strip() for x in api_status.split('|')]
            parity_tracker_dict[package_name][api_name] = ParityStatus(
                has_impl_parity=parse_parity_choice(has_impl_parity_str),
                has_doc_parity=parse_parity_choice(has_doc_parity_str))

    return parity_tracker_dict
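
The following is a minimal usage sketch, not part of the original file. It writes a small tracker file in the format described above and parses it. The sample package and API names (`torch.nn`, `nn.Linear`, and so on) are illustrative assumptions rather than the contents of the real tracker file, and the parser is repeated in condensed form (with shortened error messages) only so that the block runs on its own.

```python
import os
import tempfile
from collections import namedtuple

ParityStatus = namedtuple('ParityStatus', ['has_impl_parity', 'has_doc_parity'])


def parse_parity_tracker_table(file_path):
    # Condensed copy of the parser above so this sketch is self-contained.
    def parse_parity_choice(choice):
        if choice in ['Yes', 'No']:
            return choice == 'Yes'
        raise RuntimeError('{} is not a supported parity choice.'.format(choice))

    parity_tracker_dict = {}
    with open(file_path, 'r') as f:
        all_text = f.read()
    for package in all_text.split('##')[1:]:
        lines = [line.strip() for line in package.split('\n') if line.strip() != '']
        package_name = lines[0]
        if package_name in parity_tracker_dict:
            raise RuntimeError('Duplicated package name `{}`'.format(package_name))
        parity_tracker_dict[package_name] = {}
        # Skip the package name, the table header, and the separator row.
        for api_status in lines[3:]:
            api_name, impl_str, doc_str = [x.strip() for x in api_status.split('|')]
            parity_tracker_dict[package_name][api_name] = ParityStatus(
                has_impl_parity=parse_parity_choice(impl_str),
                has_doc_parity=parse_parity_choice(doc_str))
    return parity_tracker_dict


# Hypothetical tracker content; package and API names are illustrative only.
SAMPLE_TRACKER = """\
## torch.nn
API | Implementation Parity | Doc Parity
------------- | ------------- | -------------
nn.Linear|Yes|No
nn.Conv2d|Yes|Yes

## torch.nn.functional
API | Implementation Parity | Doc Parity
------------- | ------------- | -------------
F.relu|No|No
"""

# Write the sample to a temporary file, parse it, then clean up.
with tempfile.NamedTemporaryFile('w', suffix='.md', delete=False) as tmp:
    tmp.write(SAMPLE_TRACKER)
    tracker_path = tmp.name
try:
    parity = parse_parity_tracker_table(tracker_path)
finally:
    os.remove(tracker_path)

print(parity['torch.nn']['nn.Linear'])
```

Because the file is split on the literal `##`, a section is recognized only by its heading, and a row whose parity cell is anything other than `Yes` or `No` raises a `RuntimeError` instead of being skipped silently.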