# Owner(s): ["module: cpp"]

import os

from cpp_api_parity import (
    functional_impl_check,
    module_impl_check,
    sample_functional,
    sample_module,
)
from cpp_api_parity.parity_table_parser import parse_parity_tracker_table
from cpp_api_parity.utils import is_torch_nn_functional_test

import torch
import torch.testing._internal.common_nn as common_nn
import torch.testing._internal.common_utils as common


# NOTE: turn this on if you want to print source code of all C++ tests (e.g. for debugging purpose)
PRINT_CPP_SOURCE = False

# Devices passed to every `write_test_to_test_class` call below; the sanity
# assertions further down expect one generated test per test dict per device.
devices = ["cpu", "cuda"]

# The parity tracker table lives next to this file, under `cpp_api_parity/`.
PARITY_TABLE_PATH = os.path.join(
    os.path.dirname(__file__), "cpp_api_parity", "parity-tracker.md"
)

parity_table = parse_parity_tracker_table(PARITY_TABLE_PATH)

# Number of test dicts in each of `sample_module` / `sample_functional` that are
# not skipped (see the sample-test assertions below).
_NUM_UNSKIPPED_SAMPLE_TEST_DICTS = 2


@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCppApiParity(common.TestCase):
    """Container class that the parity checkers attach generated tests to."""

    # NOTE(review): these maps are not read or written in this file; presumably
    # they are filled in by the `*_impl_check` helpers -- confirm against them.
    module_test_params_map = {}
    functional_test_params_map = {}


def _count_generated_tests(name_fragment):
    """Return how many attribute names of `TestCppApiParity` contain `name_fragment`."""
    return sum(1 for name in TestCppApiParity.__dict__ if name_fragment in name)


# Every test dict that was handed to one of the checkers, in registration order.
expected_test_params_dicts = []

for test_params_dicts, test_instance_class in [
    (sample_module.module_tests, common_nn.NewModuleTest),
    (sample_functional.functional_tests, common_nn.NewModuleTest),
    (common_nn.module_tests, common_nn.NewModuleTest),
    (common_nn.get_new_module_tests(), common_nn.NewModuleTest),
    (common_nn.criterion_tests, common_nn.CriterionTest),
]:
    for test_params_dict in test_params_dicts:
        # Test dicts opt out of parity testing with `test_cpp_api_parity=False`;
        # a missing key means the dict is included.
        if not test_params_dict.get("test_cpp_api_parity", True):
            continue

        # Functional and module test dicts are handled by different checkers,
        # but both are called with the same arguments.
        impl_check = (
            functional_impl_check
            if is_torch_nn_functional_test(test_params_dict)
            else module_impl_check
        )
        impl_check.write_test_to_test_class(
            TestCppApiParity,
            test_params_dict,
            test_instance_class,
            parity_table,
            devices,
        )
        expected_test_params_dicts.append(test_params_dict)

# Assert that all NN module/functional test dicts appear in the parity test
_expected_total = len(expected_test_params_dicts) * len(devices)
_actual_total = _count_generated_tests("test_torch_nn_")
assert _actual_total == _expected_total, (
    f"expected {_expected_total} generated `test_torch_nn_*` tests, "
    f"found {_actual_total}"
)

# Assert that there exists auto-generated tests for `SampleModule` and `sample_functional`.
# Expected count == 2 (number of test dicts that are not skipped) * number of devices
_expected_sample_total = _NUM_UNSKIPPED_SAMPLE_TEST_DICTS * len(devices)
for _fragment in ("SampleModule", "sample_functional"):
    _actual_sample_total = _count_generated_tests(_fragment)
    assert _actual_sample_total == _expected_sample_total, (
        f"expected {_expected_sample_total} generated tests containing "
        f"{_fragment!r}, found {_actual_sample_total}"
    )

# Hand the fully populated test class to each checker's C++ build step.
module_impl_check.build_cpp_tests(TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE)
functional_impl_check.build_cpp_tests(
    TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE
)

if __name__ == "__main__":
    common.TestCase._default_dtype_check_enabled = True
    common.run_tests()