mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501.

Most changes are auto-generated by the linter. You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129757
Approved by: https://github.com/ezyang
94 lines
3.2 KiB
Python
94 lines
3.2 KiB
Python
# Owner(s): ["module: cpp"]
|
|
|
|
|
|
import os
|
|
|
|
from cpp_api_parity import (
|
|
functional_impl_check,
|
|
module_impl_check,
|
|
sample_functional,
|
|
sample_module,
|
|
)
|
|
from cpp_api_parity.parity_table_parser import parse_parity_tracker_table
|
|
from cpp_api_parity.utils import is_torch_nn_functional_test
|
|
|
|
import torch
|
|
import torch.testing._internal.common_nn as common_nn
|
|
import torch.testing._internal.common_utils as common
|
|
|
|
|
|
# NOTE: set this to True to print the source code of all C++ tests
# (e.g. while debugging).
PRINT_CPP_SOURCE = False

# Device names handed to the test writers; the sanity checks further down
# expect one generated test per (test dict, device) pair.
devices = ["cpu", "cuda"]

# The parity tracker table is a Markdown file in the `cpp_api_parity`
# directory that sits next to this test file.
PARITY_TABLE_PATH = os.path.join(
    os.path.dirname(__file__),
    "cpp_api_parity",
    "parity-tracker.md",
)

# Parsed once at import time and shared by every generated test.
parity_table = parse_parity_tracker_table(PARITY_TABLE_PATH)
|
|
|
|
|
|
@torch.testing._internal.common_utils.markDynamoStrictTest
|
|
class TestCppApiParity(common.TestCase):
|
|
module_test_params_map = {}
|
|
functional_test_params_map = {}
|
|
|
|
|
|
# Every test-parameter dict for which a parity test was written; used below to
# cross-check the number of generated test methods.
expected_test_params_dicts = []

# NOTE(review): the copy this was restored from had lost all indentation. The
# sanity asserts and the build calls are kept inside the ARM64 guard because
# nothing is generated on ARM64, so the `== 4` checks could not hold outside
# it - confirm against upstream.
if not common.IS_ARM64:
    for test_params_dicts, test_instance_class in [
        (sample_module.module_tests, common_nn.NewModuleTest),
        (sample_functional.functional_tests, common_nn.NewModuleTest),
        (common_nn.module_tests, common_nn.NewModuleTest),
        (common_nn.new_module_tests, common_nn.NewModuleTest),
        (common_nn.criterion_tests, common_nn.CriterionTest),
    ]:
        for test_params_dict in test_params_dicts:
            # A test dict opts out by setting "test_cpp_api_parity" to a
            # falsy value; a missing key means "generate the test".
            if not test_params_dict.get("test_cpp_api_parity", True):
                continue

            # Functional and module tests are written by different helpers
            # that share the same call signature.
            impl_check = (
                functional_impl_check
                if is_torch_nn_functional_test(test_params_dict)
                else module_impl_check
            )
            impl_check.write_test_to_test_class(
                TestCppApiParity,
                test_params_dict,
                test_instance_class,
                parity_table,
                devices,
            )
            expected_test_params_dicts.append(test_params_dict)

    generated_names = TestCppApiParity.__dict__

    # Assert that all NN module/functional test dicts appear in the parity test
    assert sum("test_torch_nn_" in name for name in generated_names) == len(
        expected_test_params_dicts
    ) * len(devices)

    # Assert that there exist auto-generated tests for `SampleModule` and
    # `sample_functional`.
    # 4 == 2 (number of test dicts that are not skipped) * 2 (number of devices)
    assert sum("SampleModule" in name for name in generated_names) == 4
    # 4 == 2 (number of test dicts that are not skipped) * 2 (number of devices)
    assert sum("sample_functional" in name for name in generated_names) == 4

    # Build the C++ side of the tests: module tests first, then functional.
    for impl_check in (module_impl_check, functional_impl_check):
        impl_check.build_cpp_tests(
            TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE
        )
|
|
|
|
if __name__ == "__main__":
    # Switch on TestCase's default-dtype check before handing control to the
    # shared test runner.
    common.TestCase._default_dtype_check_enabled = True
    common.run_tests()
|