Files
pytorch/test/test_cpp_api_parity.py
PyTorch MergeBot 356ac3103a Revert "Stop parsing command line arguments every time common_utils is imported. (#156703)"
This reverts commit 310f901a71e53688866b14bb2f2b4c8eef9979b3.

Reverted https://github.com/pytorch/pytorch/pull/156703 on behalf of https://github.com/izaitsevfb due to breaking tests internally with `assert common_utils.SEED is not None` ([comment](https://github.com/pytorch/pytorch/pull/156703#issuecomment-3152337518))
2025-08-04 20:37:39 +00:00

89 lines
3.0 KiB
Python

# Owner(s): ["module: cpp"]
import os
from cpp_api_parity import (
functional_impl_check,
module_impl_check,
sample_functional,
sample_module,
)
from cpp_api_parity.parity_table_parser import parse_parity_tracker_table
from cpp_api_parity.utils import is_torch_nn_functional_test
import torch
import torch.testing._internal.common_nn as common_nn
import torch.testing._internal.common_utils as common
# NOTE: set this to True if you want to print the source code of all C++ tests
# (e.g. for debugging purposes). The flag is forwarded to `build_cpp_tests`
# as `print_cpp_source`.
PRINT_CPP_SOURCE = False
# Device names handed to the test-generation helpers. The count assertion in
# this file expects one generated test per registered test dict per device.
devices = ["cpu", "cuda"]
# Location of the parity tracker table, resolved relative to this file's
# directory so it does not depend on the current working directory.
PARITY_TABLE_PATH = os.path.join(
    os.path.dirname(__file__), "cpp_api_parity", "parity-tracker.md"
)
# Parsed parity tracker table, passed along with every test dict when tests
# are attached to the test class.
# NOTE(review): the shape of the parsed result is defined in
# cpp_api_parity.parity_table_parser and is not visible from here.
parity_table = parse_parity_tracker_table(PARITY_TABLE_PATH)
@common.markDynamoStrictTest
class TestCppApiParity(common.TestCase):
    """Host test case for the auto-generated C++ API parity tests.

    The class body declares no test methods itself. Module-level code in this
    file hands the class to ``write_test_to_test_class`` and
    ``build_cpp_tests`` from the ``cpp_api_parity`` helpers.
    """

    # Class-level dicts that start out empty.
    # NOTE(review): presumably populated by the cpp_api_parity helpers while
    # tests are attached -- confirm in module_impl_check / functional_impl_check.
    module_test_params_map = {}
    functional_test_params_map = {}
# Every test dict that gets a parity test attached is recorded here; the
# count assertion later in this file compares against its length.
expected_test_params_dicts = []

# Each entry pairs a collection of test dicts with the test-instance class
# that is forwarded to the helper for every dict in that collection.
for test_params_dicts, test_instance_class in (
    (sample_module.module_tests, common_nn.NewModuleTest),
    (sample_functional.functional_tests, common_nn.NewModuleTest),
    (common_nn.module_tests, common_nn.NewModuleTest),
    (common_nn.get_new_module_tests(), common_nn.NewModuleTest),
    (common_nn.criterion_tests, common_nn.CriterionTest),
):
    for test_params_dict in test_params_dicts:
        # A test dict opts out with a falsy "test_cpp_api_parity" entry;
        # a missing key counts as opted in.
        if not test_params_dict.get("test_cpp_api_parity", True):
            continue

        # Functional tests and module tests are attached by different helper
        # modules that expose the same `write_test_to_test_class` entry point.
        (
            functional_impl_check
            if is_torch_nn_functional_test(test_params_dict)
            else module_impl_check
        ).write_test_to_test_class(
            TestCppApiParity,
            test_params_dict,
            test_instance_class,
            parity_table,
            devices,
        )
        expected_test_params_dicts.append(test_params_dict)
# Assert that all NN module/functional test dicts appear in the parity test:
# each registered test dict must account for one `test_torch_nn_*` attribute
# on the test class per device.
assert sum(
    "test_torch_nn_" in attr_name for attr_name in TestCppApiParity.__dict__
) == len(devices) * len(expected_test_params_dicts)

# Assert that there exist auto-generated tests for `SampleModule` and `sample_functional`.
# 4 == 2 (number of test dicts that are not skipped) * 2 (number of devices)
assert sum("SampleModule" in attr_name for attr_name in TestCppApiParity.__dict__) == 4
# 4 == 2 (number of test dicts that are not skipped) * 2 (number of devices)
assert (
    sum("sample_functional" in attr_name for attr_name in TestCppApiParity.__dict__)
    == 4
)
# Run the build step of each helper module on the test class, module tests
# first and functional tests second. PRINT_CPP_SOURCE is forwarded as the
# `print_cpp_source` flag.
# NOTE(review): presumably this compiles the C++ side of the generated tests;
# the actual behavior lives in cpp_api_parity -- confirm there.
module_impl_check.build_cpp_tests(TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE)
functional_impl_check.build_cpp_tests(
    TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE
)
# Script entry point: only executed when this file is run directly, not when
# it is imported.
if __name__ == "__main__":
    # Set the TestCase-level default-dtype check flag before starting the run.
    # NOTE(review): the exact effect of this flag is defined in
    # torch.testing._internal.common_utils -- confirm there.
    common.TestCase._default_dtype_check_enabled = True
    common.run_tests()