Files
pytorch/test/test_cpp_api_parity.py
Will Feng (FAIAR) 2fa3c1570d Refactor C++ API parity test mechanism and turn it on in CI again (#35190)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35190

The following are the main changes:
- The main logic of C++ API parity test mechanism is moved from `test/test_cpp_api_parity.py` to `test/cpp_api_parity/module_impl_check.py` and `test/cpp_api_parity/functional_impl_check.py`, so that there is a clear separation between module tests and functional tests, although they still share a lot of common utility functions which are all in `test/cpp_api_parity/utils.py`.
- Module init tests (i.e. tests checking whether a C++ module accepts the same constructor options as the corresponding Python module) are removed and will be added back in the future.
- `cpp_constructor_args` / `cpp_options_args` / `cpp_function_call` are added as appropriate to all test params dicts in `torch/testing/_internal/common_nn.py`, to indicate how to run the C++ API parity test for each test params dict.

Test Plan: Imported from OSS

Differential Revision: D20588198

Pulled By: yf225

fbshipit-source-id: 11238c560c8247129584b9b49df73fff40c4d81d
2020-04-03 11:20:36 -07:00

61 lines
2.7 KiB
Python

import torch
# NN tests use double as the default dtype
torch.set_default_dtype(torch.double)
import os
import torch.testing._internal.common_utils as common
import torch.testing._internal.common_nn as common_nn
from cpp_api_parity.parity_table_parser import parse_parity_tracker_table
from cpp_api_parity.utils import is_torch_nn_functional_test
from cpp_api_parity import module_impl_check, functional_impl_check, sample_module, sample_functional
# NOTE: set this to True to print the source code of every generated C++ test
# (useful when debugging).
PRINT_CPP_SOURCE = False

# Each parity test is generated once per entry in this list (the count check
# further down multiplies by `len(devices)`).
devices = ['cpu', 'cuda']

# The parity-tracker table is a Markdown file that sits next to this test file,
# inside the `cpp_api_parity` directory.
PARITY_TABLE_PATH = os.path.join(
    os.path.dirname(__file__),
    'cpp_api_parity',
    'parity-tracker.md',
)
parity_table = parse_parity_tracker_table(PARITY_TABLE_PATH)
class TestCppApiParity(common.TestCase):
    """Holder class for the auto-generated C++ API parity tests.

    The class body only declares two shared maps. The test methods themselves
    are attached to this class at import time by `module_impl_check` and
    `functional_impl_check.write_test_to_test_class`.
    """

    # Class-level maps shared by the module and functional parity checks.
    # Presumably filled in by the `*_impl_check` helpers -- TODO confirm key schema.
    module_test_params_map, functional_test_params_map = {}, {}
# Every test params dict that takes part in C++ API parity testing. Used below
# to check that the expected number of tests was generated.
expected_test_params_dicts = []

# Each entry pairs a list of test params dicts with the `common_nn` test class
# handed to the checker together with those dicts.
for test_params_dicts, test_instance_class in [
    (sample_module.module_tests, common_nn.ModuleTest),
    (sample_functional.functional_tests, common_nn.NewModuleTest),
    (common_nn.module_tests, common_nn.ModuleTest),
    (common_nn.new_module_tests, common_nn.NewModuleTest),
    (common_nn.criterion_tests, common_nn.CriterionTest),
    (common_nn.new_criterion_tests, common_nn.NewCriterionTest),
]:
    for test_params_dict in test_params_dicts:
        # Dicts are included by default; a falsy `test_cpp_api_parity` opts out.
        if not test_params_dict.get('test_cpp_api_parity', True):
            continue
        # torch.nn.functional tests and module tests are written by different checkers.
        impl_check = (functional_impl_check if is_torch_nn_functional_test(test_params_dict)
                      else module_impl_check)
        impl_check.write_test_to_test_class(
            TestCppApiParity, test_params_dict, test_instance_class, parity_table, devices)
        expected_test_params_dicts.append(test_params_dict)
def _assert_num_generated_tests(substring, expected):
    """Check how many attributes of `TestCppApiParity` have `substring` in their name.

    Args:
        substring: text to look for in each attribute name of `TestCppApiParity`.
        expected: the exact number of matching names required.

    Raises:
        AssertionError: if the number of matching names differs from `expected`.
            An explicit raise is used instead of `assert` so the check still runs
            under `python -O` and reports both counts.
    """
    actual = sum(1 for name in TestCppApiParity.__dict__ if substring in name)
    if actual != expected:
        raise AssertionError(
            "Expected {} auto-generated tests whose name contains '{}', but found {}".format(
                expected, substring, actual))


# Check that all NN module/functional test dicts appear in the parity test:
# one generated test per (test params dict, device) pair.
_assert_num_generated_tests('test_torch_nn_', len(expected_test_params_dicts) * len(devices))

# Check that there exist auto-generated tests for `SampleModule` and `sample_functional`.
# Each has 2 test dicts that are not skipped, and each dict yields one test per device.
# This count was previously hard-coded as 4 (2 dicts * 2 devices).
_assert_num_generated_tests('SampleModule', 2 * len(devices))
_assert_num_generated_tests('sample_functional', 2 * len(devices))
# Build the C++ side of the parity tests written to `TestCppApiParity`:
# module tests first, then functional tests (same order as before).
for checker in (module_impl_check, functional_impl_check):
    checker.build_cpp_tests(TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE)

if __name__ == "__main__":
    # Hand off to the shared test runner from torch.testing._internal.common_utils.
    common.run_tests()