mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Enable UFMT on test/cpp_api_parity
, test/cpp_extensions
, test/create_dummy_torchscript_model.py
, test/custom_backend
, test/custom_operator
(#123518)
Partially addresses #123062 Ran lintrunner on: - `test/cpp_api_parity` - `test/cpp_extensions` - `test/create_dummy_torchscript_model.py` - `test/custom_backend` - `test/custom_operator` Pull Request resolved: https://github.com/pytorch/pytorch/pull/123518 Approved by: https://github.com/huydhn
This commit is contained in:
committed by
PyTorch MergeBot
parent
b279034e5a
commit
c797fbc4e1
@ -14,26 +14,39 @@
|
||||
# 4. Compare Python/C++ functional's forward output. If they are the same, then we
|
||||
# have implementation parity between Python/C++ module.
|
||||
|
||||
import os
|
||||
import pprint
|
||||
import re
|
||||
import tempfile
|
||||
from string import Template
|
||||
import re
|
||||
import pprint
|
||||
import os
|
||||
|
||||
import torch
|
||||
from cpp_api_parity.utils import TorchNNFunctionalTestParams, TORCH_NN_COMMON_TEST_HARNESS, \
|
||||
compile_cpp_code_inline, set_python_tensors_requires_grad, move_python_tensors_to_device, \
|
||||
add_test, compute_cpp_args_construction_stmts_and_forward_arg_symbols, serialize_arg_dict_as_script_module, \
|
||||
compute_arg_dict, decorate_test_fn, compute_temp_file_path, generate_error_msg, is_torch_nn_functional_test, \
|
||||
try_remove_folder
|
||||
|
||||
from cpp_api_parity.sample_functional import SAMPLE_FUNCTIONAL_CPP_SOURCE
|
||||
from cpp_api_parity.utils import (
|
||||
add_test,
|
||||
compile_cpp_code_inline,
|
||||
compute_arg_dict,
|
||||
compute_cpp_args_construction_stmts_and_forward_arg_symbols,
|
||||
compute_temp_file_path,
|
||||
decorate_test_fn,
|
||||
generate_error_msg,
|
||||
is_torch_nn_functional_test,
|
||||
move_python_tensors_to_device,
|
||||
serialize_arg_dict_as_script_module,
|
||||
set_python_tensors_requires_grad,
|
||||
TORCH_NN_COMMON_TEST_HARNESS,
|
||||
TorchNNFunctionalTestParams,
|
||||
try_remove_folder,
|
||||
)
|
||||
|
||||
# Expected substitutions:
|
||||
#
|
||||
# ${functional_variant_name} (e.g. `BCELoss_no_reduce`)
|
||||
# ${cpp_args_construction_stmts}
|
||||
# ${cpp_function_call}
|
||||
TORCH_NN_FUNCTIONAL_TEST_FORWARD = Template("""
|
||||
TORCH_NN_FUNCTIONAL_TEST_FORWARD = Template(
|
||||
"""
|
||||
void ${functional_variant_name}_test_forward(
|
||||
const std::string& arg_dict_file_path,
|
||||
const std::string& forward_output_file_path) {
|
||||
@ -56,17 +69,24 @@ void ${functional_variant_name}_test_forward(
|
||||
// Save the output into a file to be compared in Python later
|
||||
write_ivalue_to_file(torch::IValue(cpp_output), forward_output_file_path);
|
||||
}
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def run_forward(unit_test_class, test_params):
|
||||
device = test_params.device
|
||||
|
||||
inputs = set_python_tensors_requires_grad(move_python_tensors_to_device(
|
||||
[arg_value for _, arg_value in test_params.arg_dict['input']], device))
|
||||
inputs = set_python_tensors_requires_grad(
|
||||
move_python_tensors_to_device(
|
||||
[arg_value for _, arg_value in test_params.arg_dict["input"]], device
|
||||
)
|
||||
)
|
||||
inputs += move_python_tensors_to_device(
|
||||
[arg_value for _, arg_value in test_params.arg_dict['target']], device)
|
||||
[arg_value for _, arg_value in test_params.arg_dict["target"]], device
|
||||
)
|
||||
inputs += move_python_tensors_to_device(
|
||||
[arg_value for _, arg_value in test_params.arg_dict['extra_args']], device)
|
||||
[arg_value for _, arg_value in test_params.arg_dict["extra_args"]], device
|
||||
)
|
||||
|
||||
# Some functionals (such as `F.rrelu`) create random tensors in their call path.
|
||||
# To make sure the random tensors created are the same in Python/C++, we need
|
||||
@ -76,6 +96,7 @@ def run_forward(unit_test_class, test_params):
|
||||
|
||||
return python_output
|
||||
|
||||
|
||||
def test_forward(unit_test_class, test_params):
|
||||
functional_variant_name = test_params.functional_variant_name
|
||||
cpp_tmp_folder = test_params.cpp_tmp_folder
|
||||
@ -87,105 +108,149 @@ def test_forward(unit_test_class, test_params):
|
||||
python_output = run_forward(unit_test_class, test_params)
|
||||
|
||||
# Save Python arguments to be used from C++ function
|
||||
arg_dict_file_path = compute_temp_file_path(cpp_tmp_folder, functional_variant_name, 'arg_dict')
|
||||
arg_dict_file_path = compute_temp_file_path(
|
||||
cpp_tmp_folder, functional_variant_name, "arg_dict"
|
||||
)
|
||||
serialize_arg_dict_as_script_module(test_params.arg_dict).save(arg_dict_file_path)
|
||||
|
||||
cpp_test_name = f'{test_params.functional_variant_name}_test_forward'
|
||||
cpp_test_fn = getattr(unit_test_class.functional_impl_check_cpp_module, cpp_test_name)
|
||||
cpp_test_name = f"{test_params.functional_variant_name}_test_forward"
|
||||
cpp_test_fn = getattr(
|
||||
unit_test_class.functional_impl_check_cpp_module, cpp_test_name
|
||||
)
|
||||
|
||||
def run_cpp_test_fn_and_check_output():
|
||||
forward_output_file_path = compute_temp_file_path(cpp_tmp_folder, functional_variant_name, 'forward_output')
|
||||
forward_output_file_path = compute_temp_file_path(
|
||||
cpp_tmp_folder, functional_variant_name, "forward_output"
|
||||
)
|
||||
|
||||
cpp_test_fn(arg_dict_file_path, forward_output_file_path)
|
||||
cpp_output = torch.load(forward_output_file_path)
|
||||
|
||||
# Check that forward outputs are equal
|
||||
unit_test_class.assertEqual(
|
||||
python_output, cpp_output,
|
||||
msg=generate_error_msg("forward output", cpp_output, python_output))
|
||||
python_output,
|
||||
cpp_output,
|
||||
msg=generate_error_msg("forward output", cpp_output, python_output),
|
||||
)
|
||||
|
||||
run_cpp_test_fn_and_check_output()
|
||||
|
||||
# Remove temporary folder that stores C++ outputs
|
||||
try_remove_folder(cpp_tmp_folder)
|
||||
|
||||
|
||||
def compute_functional_name(test_params_dict):
|
||||
def camel_case_to_snake_case(camel_case_str):
|
||||
return re.sub(r'(?<!^)(?=[A-Z])', '_', camel_case_str).lower()
|
||||
return re.sub(r"(?<!^)(?=[A-Z])", "_", camel_case_str).lower()
|
||||
|
||||
if 'cpp_options_args' in test_params_dict:
|
||||
if "cpp_options_args" in test_params_dict:
|
||||
# Expected format for `cpp_options_args`: `F::FunctionalFuncOptions(...)`
|
||||
# Example output: `binary_cross_entropy`
|
||||
return camel_case_to_snake_case(
|
||||
test_params_dict['cpp_options_args'].split('(')[0].replace('F::', '').replace('FuncOptions', ''))
|
||||
elif 'cpp_function_call' in test_params_dict:
|
||||
test_params_dict["cpp_options_args"]
|
||||
.split("(")[0]
|
||||
.replace("F::", "")
|
||||
.replace("FuncOptions", "")
|
||||
)
|
||||
elif "cpp_function_call" in test_params_dict:
|
||||
# Expected format for `cpp_function_call`: `F::functional_name(...)`
|
||||
# Example output: `binary_cross_entropy`
|
||||
return test_params_dict['cpp_function_call'].split('(')[0].replace('F::', '')
|
||||
return test_params_dict["cpp_function_call"].split("(")[0].replace("F::", "")
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{}".format(
|
||||
pprint.pformat(test_params_dict)))
|
||||
pprint.pformat(test_params_dict)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def compute_cpp_function_call(test_params_dict, arg_dict, functional_name):
|
||||
if 'cpp_function_call' in test_params_dict:
|
||||
return test_params_dict['cpp_function_call']
|
||||
elif 'cpp_options_args' in test_params_dict:
|
||||
cpp_forward_args_symbols = [arg_name for arg_name, _ in
|
||||
arg_dict['input'] + arg_dict['target'] + arg_dict['extra_args']]
|
||||
return 'F::{}({}, {})'.format(
|
||||
functional_name, ", ".join(cpp_forward_args_symbols), test_params_dict['cpp_options_args'])
|
||||
if "cpp_function_call" in test_params_dict:
|
||||
return test_params_dict["cpp_function_call"]
|
||||
elif "cpp_options_args" in test_params_dict:
|
||||
cpp_forward_args_symbols = [
|
||||
arg_name
|
||||
for arg_name, _ in arg_dict["input"]
|
||||
+ arg_dict["target"]
|
||||
+ arg_dict["extra_args"]
|
||||
]
|
||||
return "F::{}({}, {})".format(
|
||||
functional_name,
|
||||
", ".join(cpp_forward_args_symbols),
|
||||
test_params_dict["cpp_options_args"],
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{}".format(
|
||||
pprint.pformat(test_params_dict)))
|
||||
pprint.pformat(test_params_dict)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def process_test_params_for_functional(test_params_dict, device, test_instance_class):
|
||||
test_instance = test_instance_class(**test_params_dict)
|
||||
functional_name = compute_functional_name(test_params_dict)
|
||||
assert test_instance.get_name().startswith('test_')
|
||||
assert test_instance.get_name().startswith("test_")
|
||||
# Example output: `BCELoss_no_reduce_cuda`
|
||||
functional_variant_name = test_instance.get_name()[5:] + (('_' + device) if device != 'cpu' else '')
|
||||
functional_variant_name = test_instance.get_name()[5:] + (
|
||||
("_" + device) if device != "cpu" else ""
|
||||
)
|
||||
arg_dict = compute_arg_dict(test_params_dict, test_instance)
|
||||
|
||||
return TorchNNFunctionalTestParams(
|
||||
functional_name=functional_name,
|
||||
functional_variant_name=functional_variant_name,
|
||||
test_instance=test_instance,
|
||||
cpp_function_call=compute_cpp_function_call(test_params_dict, arg_dict, functional_name),
|
||||
cpp_function_call=compute_cpp_function_call(
|
||||
test_params_dict, arg_dict, functional_name
|
||||
),
|
||||
arg_dict=arg_dict,
|
||||
has_parity=test_params_dict.get('has_parity', True),
|
||||
has_parity=test_params_dict.get("has_parity", True),
|
||||
device=device,
|
||||
cpp_tmp_folder=tempfile.mkdtemp(),
|
||||
)
|
||||
|
||||
|
||||
def write_test_to_test_class(
|
||||
unit_test_class, test_params_dict, test_instance_class, parity_table, devices):
|
||||
unit_test_class, test_params_dict, test_instance_class, parity_table, devices
|
||||
):
|
||||
assert is_torch_nn_functional_test(test_params_dict)
|
||||
|
||||
assert 'cpp_options_args' in test_params_dict or 'cpp_function_call' in test_params_dict, (
|
||||
assert (
|
||||
"cpp_options_args" in test_params_dict
|
||||
or "cpp_function_call" in test_params_dict
|
||||
), (
|
||||
"To enable C++ API parity test, "
|
||||
"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{}. \n"
|
||||
"If you are interested in adding the C++ API parity test, please see:\n"
|
||||
"NOTE [How to check NN module / functional API parity between Python and C++ frontends]. \n"
|
||||
"If not, please add `test_cpp_api_parity=False` to the test params dict and file an issue about this."
|
||||
).format(pprint.pformat(test_params_dict))
|
||||
).format(
|
||||
pprint.pformat(test_params_dict)
|
||||
)
|
||||
|
||||
assert not ('cpp_options_args' in test_params_dict and 'cpp_function_call' in test_params_dict), (
|
||||
assert not (
|
||||
"cpp_options_args" in test_params_dict
|
||||
and "cpp_function_call" in test_params_dict
|
||||
), (
|
||||
"Only one of `cpp_options_args` and `cpp_function_call` entries "
|
||||
f"should be present in test params dict:\n{pprint.pformat(test_params_dict)}")
|
||||
f"should be present in test params dict:\n{pprint.pformat(test_params_dict)}"
|
||||
)
|
||||
|
||||
functional_name = compute_functional_name(test_params_dict)
|
||||
|
||||
assert hasattr(torch.nn.functional, functional_name), \
|
||||
"`torch.nn.functional` doesn't have function `{}`. (Discovered while processing\n{}.)".format(
|
||||
functional_name, pprint.pformat(test_params_dict))
|
||||
assert hasattr(
|
||||
torch.nn.functional, functional_name
|
||||
), "`torch.nn.functional` doesn't have function `{}`. (Discovered while processing\n{}.)".format(
|
||||
functional_name, pprint.pformat(test_params_dict)
|
||||
)
|
||||
|
||||
functional_full_name = 'F::' + functional_name
|
||||
functional_full_name = "F::" + functional_name
|
||||
|
||||
assert functional_full_name in parity_table['torch::nn::functional'], (
|
||||
assert functional_full_name in parity_table["torch::nn::functional"], (
|
||||
"Please add `{}` entry to `torch::nn::functional` section of `test/cpp_api_parity/parity-tracker.md`. "
|
||||
"(Discovered while processing\n{}.)").format(functional_full_name, pprint.pformat(test_params_dict))
|
||||
"(Discovered while processing\n{}.)"
|
||||
).format(functional_full_name, pprint.pformat(test_params_dict))
|
||||
|
||||
for device in devices:
|
||||
test_params = process_test_params_for_functional(
|
||||
@ -194,24 +259,37 @@ def write_test_to_test_class(
|
||||
test_instance_class=test_instance_class,
|
||||
)
|
||||
try_remove_folder(test_params.cpp_tmp_folder)
|
||||
unit_test_name = f'test_torch_nn_functional_{test_params.functional_variant_name}'
|
||||
unit_test_name = (
|
||||
f"test_torch_nn_functional_{test_params.functional_variant_name}"
|
||||
)
|
||||
unit_test_class.functional_test_params_map[unit_test_name] = test_params
|
||||
|
||||
def test_fn(self):
|
||||
test_forward(
|
||||
unit_test_class=self, test_params=unit_test_class.functional_test_params_map[self._testMethodName])
|
||||
unit_test_class=self,
|
||||
test_params=unit_test_class.functional_test_params_map[
|
||||
self._testMethodName
|
||||
],
|
||||
)
|
||||
|
||||
test_fn = decorate_test_fn(
|
||||
test_fn=test_fn,
|
||||
test_cuda=test_params_dict.get('test_cuda', True),
|
||||
has_impl_parity=parity_table['torch::nn::functional'][functional_full_name][0] and
|
||||
test_params_dict.get('has_parity', True),
|
||||
device=device)
|
||||
test_cuda=test_params_dict.get("test_cuda", True),
|
||||
has_impl_parity=parity_table["torch::nn::functional"][functional_full_name][
|
||||
0
|
||||
]
|
||||
and test_params_dict.get("has_parity", True),
|
||||
device=device,
|
||||
)
|
||||
|
||||
add_test(unit_test_class, unit_test_name, test_fn)
|
||||
|
||||
|
||||
def generate_test_cpp_sources(test_params, template):
|
||||
cpp_args_construction_stmts, _ = compute_cpp_args_construction_stmts_and_forward_arg_symbols(test_params)
|
||||
(
|
||||
cpp_args_construction_stmts,
|
||||
_,
|
||||
) = compute_cpp_args_construction_stmts_and_forward_arg_symbols(test_params)
|
||||
|
||||
test_cpp_sources = template.substitute(
|
||||
functional_variant_name=test_params.functional_variant_name,
|
||||
@ -220,19 +298,21 @@ def generate_test_cpp_sources(test_params, template):
|
||||
)
|
||||
return test_cpp_sources
|
||||
|
||||
|
||||
# Build all C++ tests together, instead of once per test.
|
||||
def build_cpp_tests(unit_test_class, print_cpp_source=False):
|
||||
assert len(unit_test_class.functional_test_params_map) > 0
|
||||
cpp_sources = TORCH_NN_COMMON_TEST_HARNESS + SAMPLE_FUNCTIONAL_CPP_SOURCE
|
||||
functions = []
|
||||
for test_params in unit_test_class.functional_test_params_map.values():
|
||||
cpp_sources += generate_test_cpp_sources(test_params=test_params, template=TORCH_NN_FUNCTIONAL_TEST_FORWARD)
|
||||
functions.append(f'{test_params.functional_variant_name}_test_forward')
|
||||
cpp_sources += generate_test_cpp_sources(
|
||||
test_params=test_params, template=TORCH_NN_FUNCTIONAL_TEST_FORWARD
|
||||
)
|
||||
functions.append(f"{test_params.functional_variant_name}_test_forward")
|
||||
if print_cpp_source:
|
||||
print(cpp_sources)
|
||||
|
||||
cpp_module = compile_cpp_code_inline(
|
||||
name='functional_impl_check',
|
||||
cpp_sources=cpp_sources,
|
||||
functions=functions)
|
||||
name="functional_impl_check", cpp_sources=cpp_sources, functions=functions
|
||||
)
|
||||
unit_test_class.functional_impl_check_cpp_module = cpp_module
|
||||
|
@ -14,19 +14,31 @@
|
||||
# 4. Compare Python/C++ module's forward output and backward gradients. If they
|
||||
# are the same, then we have implementation parity between Python/C++ module.
|
||||
|
||||
import tempfile
|
||||
from string import Template
|
||||
import types
|
||||
import pprint
|
||||
import os
|
||||
import pprint
|
||||
import tempfile
|
||||
import types
|
||||
from string import Template
|
||||
|
||||
import torch
|
||||
from cpp_api_parity.utils import TorchNNModuleTestParams, TORCH_NN_COMMON_TEST_HARNESS, \
|
||||
compile_cpp_code_inline, set_python_tensors_requires_grad, move_python_tensors_to_device, \
|
||||
add_test, compute_cpp_args_construction_stmts_and_forward_arg_symbols, serialize_arg_dict_as_script_module, \
|
||||
compute_arg_dict, decorate_test_fn, compute_temp_file_path, generate_error_msg, is_torch_nn_functional_test, \
|
||||
try_remove_folder
|
||||
|
||||
from cpp_api_parity.sample_module import SAMPLE_MODULE_CPP_SOURCE
|
||||
from cpp_api_parity.utils import (
|
||||
add_test,
|
||||
compile_cpp_code_inline,
|
||||
compute_arg_dict,
|
||||
compute_cpp_args_construction_stmts_and_forward_arg_symbols,
|
||||
compute_temp_file_path,
|
||||
decorate_test_fn,
|
||||
generate_error_msg,
|
||||
is_torch_nn_functional_test,
|
||||
move_python_tensors_to_device,
|
||||
serialize_arg_dict_as_script_module,
|
||||
set_python_tensors_requires_grad,
|
||||
TORCH_NN_COMMON_TEST_HARNESS,
|
||||
TorchNNModuleTestParams,
|
||||
try_remove_folder,
|
||||
)
|
||||
|
||||
# Expected substitutions:
|
||||
#
|
||||
@ -36,7 +48,8 @@ from cpp_api_parity.sample_module import SAMPLE_MODULE_CPP_SOURCE
|
||||
# ${cpp_constructor_args}
|
||||
# ${device}
|
||||
# ${cpp_forward_args_symbols}
|
||||
TORCH_NN_MODULE_TEST_FORWARD_BACKWARD = Template("""
|
||||
TORCH_NN_MODULE_TEST_FORWARD_BACKWARD = Template(
|
||||
"""
|
||||
void ${module_variant_name}_test_forward_backward(
|
||||
const std::string& arg_dict_file_path,
|
||||
const std::string& module_file_path,
|
||||
@ -85,18 +98,27 @@ void ${module_variant_name}_test_forward_backward(
|
||||
|
||||
write_ivalue_to_file(torch::IValue(grad_dict), backward_grad_dict_file_path);
|
||||
}
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def run_python_forward_backward(unit_test_class, test_params):
|
||||
device = test_params.device
|
||||
module = test_params.test_instance.constructor(*test_params.test_instance.constructor_args).to(device)
|
||||
module = test_params.test_instance.constructor(
|
||||
*test_params.test_instance.constructor_args
|
||||
).to(device)
|
||||
|
||||
inputs = set_python_tensors_requires_grad(move_python_tensors_to_device(
|
||||
[arg_value for _, arg_value in test_params.arg_dict['input']], device))
|
||||
inputs = set_python_tensors_requires_grad(
|
||||
move_python_tensors_to_device(
|
||||
[arg_value for _, arg_value in test_params.arg_dict["input"]], device
|
||||
)
|
||||
)
|
||||
inputs += move_python_tensors_to_device(
|
||||
[arg_value for _, arg_value in test_params.arg_dict['target']], device)
|
||||
[arg_value for _, arg_value in test_params.arg_dict["target"]], device
|
||||
)
|
||||
inputs += move_python_tensors_to_device(
|
||||
[arg_value for _, arg_value in test_params.arg_dict['extra_args']], device)
|
||||
[arg_value for _, arg_value in test_params.arg_dict["extra_args"]], device
|
||||
)
|
||||
|
||||
# Some modules (such as `RReLU`) create random tensors in their forward pass.
|
||||
# To make sure the random tensors created are the same in Python/C++, we need
|
||||
@ -130,6 +152,7 @@ def run_python_forward_backward(unit_test_class, test_params):
|
||||
|
||||
return script_module, python_output, python_grad_dict
|
||||
|
||||
|
||||
def test_forward_backward(unit_test_class, test_params):
|
||||
module_variant_name = test_params.module_variant_name
|
||||
cpp_tmp_folder = test_params.cpp_tmp_folder
|
||||
@ -138,76 +161,113 @@ def test_forward_backward(unit_test_class, test_params):
|
||||
os.mkdir(cpp_tmp_folder)
|
||||
|
||||
# Run forward and backward on Python module
|
||||
script_module, python_output, python_grad_dict = run_python_forward_backward(unit_test_class, test_params)
|
||||
script_module, python_output, python_grad_dict = run_python_forward_backward(
|
||||
unit_test_class, test_params
|
||||
)
|
||||
|
||||
# Save Python module and arguments to be used from C++ function
|
||||
module_file_path = compute_temp_file_path(cpp_tmp_folder, module_variant_name, 'module')
|
||||
arg_dict_file_path = compute_temp_file_path(cpp_tmp_folder, module_variant_name, 'arg_dict')
|
||||
module_file_path = compute_temp_file_path(
|
||||
cpp_tmp_folder, module_variant_name, "module"
|
||||
)
|
||||
arg_dict_file_path = compute_temp_file_path(
|
||||
cpp_tmp_folder, module_variant_name, "arg_dict"
|
||||
)
|
||||
script_module.save(module_file_path)
|
||||
serialize_arg_dict_as_script_module(test_params.arg_dict).save(arg_dict_file_path)
|
||||
|
||||
cpp_test_name = f'{test_params.module_variant_name}_test_forward_backward'
|
||||
cpp_test_name = f"{test_params.module_variant_name}_test_forward_backward"
|
||||
cpp_test_fn = getattr(unit_test_class.module_impl_check_cpp_module, cpp_test_name)
|
||||
|
||||
def run_cpp_test_fn_and_check_output():
|
||||
forward_output_file_path = compute_temp_file_path(cpp_tmp_folder, module_variant_name, 'forward_output')
|
||||
backward_grad_dict_file_path = compute_temp_file_path(cpp_tmp_folder, module_variant_name, 'backward_grad_dict')
|
||||
forward_output_file_path = compute_temp_file_path(
|
||||
cpp_tmp_folder, module_variant_name, "forward_output"
|
||||
)
|
||||
backward_grad_dict_file_path = compute_temp_file_path(
|
||||
cpp_tmp_folder, module_variant_name, "backward_grad_dict"
|
||||
)
|
||||
|
||||
cpp_test_fn(arg_dict_file_path, module_file_path, forward_output_file_path, backward_grad_dict_file_path)
|
||||
cpp_test_fn(
|
||||
arg_dict_file_path,
|
||||
module_file_path,
|
||||
forward_output_file_path,
|
||||
backward_grad_dict_file_path,
|
||||
)
|
||||
cpp_output = torch.load(forward_output_file_path)
|
||||
cpp_grad_dict = torch.load(backward_grad_dict_file_path)
|
||||
|
||||
# Check that forward outputs are equal
|
||||
unit_test_class.assertEqual(python_output, cpp_output,
|
||||
msg=generate_error_msg("forward output", cpp_output, python_output))
|
||||
unit_test_class.assertEqual(
|
||||
python_output,
|
||||
cpp_output,
|
||||
msg=generate_error_msg("forward output", cpp_output, python_output),
|
||||
)
|
||||
|
||||
# Check that module parameter gradients are equal after backward pass
|
||||
unit_test_class.assertEqual(
|
||||
len(python_grad_dict), len(cpp_grad_dict),
|
||||
msg=generate_error_msg("# of parameters", len(cpp_grad_dict), len(python_grad_dict)))
|
||||
len(python_grad_dict),
|
||||
len(cpp_grad_dict),
|
||||
msg=generate_error_msg(
|
||||
"# of parameters", len(cpp_grad_dict), len(python_grad_dict)
|
||||
),
|
||||
)
|
||||
for key in python_grad_dict:
|
||||
param_name = None
|
||||
for suffix in ['_grad', '_grad_indices', '_grad_values']:
|
||||
for suffix in ["_grad", "_grad_indices", "_grad_values"]:
|
||||
if key.endswith(suffix):
|
||||
param_name = key[:-len(suffix)]
|
||||
param_name = key[: -len(suffix)]
|
||||
break
|
||||
assert param_name is not None
|
||||
sparsity_str = 'sparse' if key.endswith(('_grad_indices', '_grad_values')) else 'dense'
|
||||
sparsity_str = (
|
||||
"sparse" if key.endswith(("_grad_indices", "_grad_values")) else "dense"
|
||||
)
|
||||
|
||||
unit_test_class.assertTrue(
|
||||
key in cpp_grad_dict,
|
||||
msg=generate_error_msg(
|
||||
f"\"Does module have a parameter named `{param_name}` with {sparsity_str} gradient?\"",
|
||||
False, True))
|
||||
f'"Does module have a parameter named `{param_name}` with {sparsity_str} gradient?"',
|
||||
False,
|
||||
True,
|
||||
),
|
||||
)
|
||||
unit_test_class.assertEqual(
|
||||
python_grad_dict[key], cpp_grad_dict[key],
|
||||
python_grad_dict[key],
|
||||
cpp_grad_dict[key],
|
||||
msg=generate_error_msg(
|
||||
f"`{param_name}`'s {sparsity_str} gradient (`{key}`)",
|
||||
cpp_grad_dict[key], python_grad_dict[key]))
|
||||
cpp_grad_dict[key],
|
||||
python_grad_dict[key],
|
||||
),
|
||||
)
|
||||
|
||||
run_cpp_test_fn_and_check_output()
|
||||
|
||||
# Remove temporary folder that stores C++ outputs
|
||||
try_remove_folder(cpp_tmp_folder)
|
||||
|
||||
|
||||
def compute_module_name(test_params_dict):
|
||||
fullname = test_params_dict.get('fullname', None)
|
||||
fullname = test_params_dict.get("fullname", None)
|
||||
if fullname:
|
||||
module_name = fullname.split('_')[0]
|
||||
module_name = fullname.split("_")[0]
|
||||
else:
|
||||
module_name = test_params_dict.get('module_name')
|
||||
module_name = test_params_dict.get("module_name")
|
||||
return module_name
|
||||
|
||||
|
||||
def process_test_params_for_module(test_params_dict, device, test_instance_class):
|
||||
module_name = compute_module_name(test_params_dict)
|
||||
test_params_dict['constructor'] = test_params_dict.get('constructor', getattr(torch.nn, module_name))
|
||||
test_params_dict["constructor"] = test_params_dict.get(
|
||||
"constructor", getattr(torch.nn, module_name)
|
||||
)
|
||||
test_instance = test_instance_class(**test_params_dict)
|
||||
assert test_instance.get_name().startswith('test_')
|
||||
assert test_instance.get_name().startswith("test_")
|
||||
# Example output: `BCELoss_weights_cuda`
|
||||
module_variant_name = test_instance.get_name()[5:] + (('_' + device) if device != 'cpu' else '')
|
||||
module_variant_name = test_instance.get_name()[5:] + (
|
||||
("_" + device) if device != "cpu" else ""
|
||||
)
|
||||
|
||||
if 'constructor_args' in test_params_dict:
|
||||
assert 'cpp_constructor_args' in test_params_dict, (
|
||||
if "constructor_args" in test_params_dict:
|
||||
assert "cpp_constructor_args" in test_params_dict, (
|
||||
"If `constructor_args` is present in test params dict, to enable C++ API parity test, "
|
||||
f"`cpp_constructor_args` must be present in:\n{pprint.pformat(test_params_dict)}"
|
||||
"If you are interested in adding the C++ API parity test, please see:\n"
|
||||
@ -219,15 +279,17 @@ def process_test_params_for_module(test_params_dict, device, test_instance_class
|
||||
module_name=module_name,
|
||||
module_variant_name=module_variant_name,
|
||||
test_instance=test_instance,
|
||||
cpp_constructor_args=test_params_dict.get('cpp_constructor_args', ''),
|
||||
cpp_constructor_args=test_params_dict.get("cpp_constructor_args", ""),
|
||||
arg_dict=compute_arg_dict(test_params_dict, test_instance),
|
||||
has_parity=test_params_dict.get('has_parity', True),
|
||||
has_parity=test_params_dict.get("has_parity", True),
|
||||
device=device,
|
||||
cpp_tmp_folder=tempfile.mkdtemp(),
|
||||
)
|
||||
|
||||
|
||||
def write_test_to_test_class(
|
||||
unit_test_class, test_params_dict, test_instance_class, parity_table, devices):
|
||||
unit_test_class, test_params_dict, test_instance_class, parity_table, devices
|
||||
):
|
||||
assert not is_torch_nn_functional_test(test_params_dict)
|
||||
|
||||
module_name = compute_module_name(test_params_dict)
|
||||
@ -238,11 +300,12 @@ def write_test_to_test_class(
|
||||
f"or set `module_name` using format `ModuleName` in the module test dict:\n{pprint.pformat(test_params_dict)}"
|
||||
)
|
||||
|
||||
module_full_name = 'torch::nn::' + module_name
|
||||
module_full_name = "torch::nn::" + module_name
|
||||
|
||||
assert module_full_name in parity_table['torch::nn'], (
|
||||
assert module_full_name in parity_table["torch::nn"], (
|
||||
f"Please add `{module_full_name}` entry to `torch::nn` section of `test/cpp_api_parity/parity-tracker.md`. "
|
||||
f"(Discovered while processing\n{pprint.pformat(test_params_dict)}.)")
|
||||
f"(Discovered while processing\n{pprint.pformat(test_params_dict)}.)"
|
||||
)
|
||||
|
||||
for device in devices:
|
||||
test_params = process_test_params_for_module(
|
||||
@ -251,35 +314,43 @@ def write_test_to_test_class(
|
||||
test_instance_class=test_instance_class,
|
||||
)
|
||||
try_remove_folder(test_params.cpp_tmp_folder)
|
||||
unit_test_name = f'test_torch_nn_{test_params.module_variant_name}'
|
||||
unit_test_name = f"test_torch_nn_{test_params.module_variant_name}"
|
||||
unit_test_class.module_test_params_map[unit_test_name] = test_params
|
||||
|
||||
def test_fn(self):
|
||||
test_forward_backward(
|
||||
unit_test_class=self, test_params=unit_test_class.module_test_params_map[self._testMethodName])
|
||||
unit_test_class=self,
|
||||
test_params=unit_test_class.module_test_params_map[
|
||||
self._testMethodName
|
||||
],
|
||||
)
|
||||
|
||||
test_fn = decorate_test_fn(
|
||||
test_fn=test_fn,
|
||||
test_cuda=test_params_dict.get('test_cuda', True),
|
||||
has_impl_parity=parity_table['torch::nn'][module_full_name][0] and
|
||||
test_params_dict.get('has_parity', True),
|
||||
device=device)
|
||||
test_cuda=test_params_dict.get("test_cuda", True),
|
||||
has_impl_parity=parity_table["torch::nn"][module_full_name][0]
|
||||
and test_params_dict.get("has_parity", True),
|
||||
device=device,
|
||||
)
|
||||
|
||||
add_test(unit_test_class, unit_test_name, test_fn)
|
||||
|
||||
|
||||
def generate_test_cpp_sources(test_params, template):
|
||||
device = test_params.device
|
||||
|
||||
cpp_constructor_args = test_params.cpp_constructor_args
|
||||
if cpp_constructor_args != '':
|
||||
cpp_constructor_args = f'({cpp_constructor_args})'
|
||||
if cpp_constructor_args != "":
|
||||
cpp_constructor_args = f"({cpp_constructor_args})"
|
||||
|
||||
cpp_args_construction_stmts, cpp_forward_args_symbols = \
|
||||
compute_cpp_args_construction_stmts_and_forward_arg_symbols(test_params)
|
||||
(
|
||||
cpp_args_construction_stmts,
|
||||
cpp_forward_args_symbols,
|
||||
) = compute_cpp_args_construction_stmts_and_forward_arg_symbols(test_params)
|
||||
|
||||
test_cpp_sources = template.substitute(
|
||||
module_variant_name=test_params.module_variant_name,
|
||||
module_qualified_name=f'torch::nn::{test_params.module_name}',
|
||||
module_qualified_name=f"torch::nn::{test_params.module_name}",
|
||||
cpp_args_construction_stmts=";\n ".join(cpp_args_construction_stmts),
|
||||
cpp_constructor_args=cpp_constructor_args,
|
||||
cpp_forward_args_symbols=", ".join(cpp_forward_args_symbols),
|
||||
@ -287,6 +358,7 @@ def generate_test_cpp_sources(test_params, template):
|
||||
)
|
||||
return test_cpp_sources
|
||||
|
||||
|
||||
# Build all C++ tests together, instead of once per test.
|
||||
def build_cpp_tests(unit_test_class, print_cpp_source=False):
|
||||
assert len(unit_test_class.module_test_params_map) > 0
|
||||
@ -294,13 +366,13 @@ def build_cpp_tests(unit_test_class, print_cpp_source=False):
|
||||
functions = []
|
||||
for test_params in unit_test_class.module_test_params_map.values():
|
||||
cpp_sources += generate_test_cpp_sources(
|
||||
test_params=test_params, template=TORCH_NN_MODULE_TEST_FORWARD_BACKWARD)
|
||||
functions.append(f'{test_params.module_variant_name}_test_forward_backward')
|
||||
test_params=test_params, template=TORCH_NN_MODULE_TEST_FORWARD_BACKWARD
|
||||
)
|
||||
functions.append(f"{test_params.module_variant_name}_test_forward_backward")
|
||||
if print_cpp_source:
|
||||
print(cpp_sources)
|
||||
|
||||
cpp_module = compile_cpp_code_inline(
|
||||
name='module_impl_check',
|
||||
cpp_sources=cpp_sources,
|
||||
functions=functions)
|
||||
name="module_impl_check", cpp_sources=cpp_sources, functions=functions
|
||||
)
|
||||
unit_test_class.module_impl_check_cpp_module = cpp_module
|
||||
|
@ -1,8 +1,8 @@
|
||||
from collections import namedtuple
|
||||
|
||||
ParityStatus = namedtuple('ParityStatus', ['has_impl_parity', 'has_doc_parity'])
|
||||
ParityStatus = namedtuple("ParityStatus", ["has_impl_parity", "has_doc_parity"])
|
||||
|
||||
'''
|
||||
"""
|
||||
This function expects the parity tracker Markdown file to have the following format:
|
||||
|
||||
```
|
||||
@ -28,31 +28,39 @@ Dict[package_name]
|
||||
-> Dict[api_name]
|
||||
-> ParityStatus
|
||||
```
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
def parse_parity_tracker_table(file_path):
|
||||
def parse_parity_choice(str):
|
||||
if str in ['Yes', 'No']:
|
||||
return str == 'Yes'
|
||||
if str in ["Yes", "No"]:
|
||||
return str == "Yes"
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f'{str} is not a supported parity choice. The valid choices are "Yes" and "No".')
|
||||
f'{str} is not a supported parity choice. The valid choices are "Yes" and "No".'
|
||||
)
|
||||
|
||||
parity_tracker_dict = {}
|
||||
|
||||
with open(file_path) as f:
|
||||
all_text = f.read()
|
||||
packages = all_text.split('##')
|
||||
packages = all_text.split("##")
|
||||
for package in packages[1:]:
|
||||
lines = [line.strip() for line in package.split('\n') if line.strip() != '']
|
||||
lines = [line.strip() for line in package.split("\n") if line.strip() != ""]
|
||||
package_name = lines[0]
|
||||
if package_name in parity_tracker_dict:
|
||||
raise RuntimeError(f"Duplicated package name `{package_name}` found in {file_path}")
|
||||
raise RuntimeError(
|
||||
f"Duplicated package name `{package_name}` found in {file_path}"
|
||||
)
|
||||
else:
|
||||
parity_tracker_dict[package_name] = {}
|
||||
for api_status in lines[3:]:
|
||||
api_name, has_impl_parity_str, has_doc_parity_str = (x.strip() for x in api_status.split('|'))
|
||||
api_name, has_impl_parity_str, has_doc_parity_str = (
|
||||
x.strip() for x in api_status.split("|")
|
||||
)
|
||||
parity_tracker_dict[package_name][api_name] = ParityStatus(
|
||||
has_impl_parity=parse_parity_choice(has_impl_parity_str),
|
||||
has_doc_parity=parse_parity_choice(has_doc_parity_str))
|
||||
has_doc_parity=parse_parity_choice(has_doc_parity_str),
|
||||
)
|
||||
|
||||
return parity_tracker_dict
|
||||
|
@ -2,7 +2,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.testing._internal.common_nn import wrap_functional
|
||||
|
||||
'''
|
||||
"""
|
||||
`sample_functional` is used by `test_cpp_api_parity.py` to test that Python / C++ API
|
||||
parity test harness works for `torch.nn.functional` functions.
|
||||
|
||||
@ -11,7 +11,8 @@ is the same as the C++ equivalent.
|
||||
|
||||
When `has_parity=false` is passed to `sample_functional`, behavior of `sample_functional`
|
||||
is different from the C++ equivalent.
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
def sample_functional(x, has_parity):
|
||||
if has_parity:
|
||||
@ -19,6 +20,7 @@ def sample_functional(x, has_parity):
|
||||
else:
|
||||
return x * 4
|
||||
|
||||
|
||||
torch.nn.functional.sample_functional = sample_functional
|
||||
|
||||
SAMPLE_FUNCTIONAL_CPP_SOURCE = """\n
|
||||
@ -44,16 +46,16 @@ Tensor sample_functional(Tensor x, SampleFunctionalFuncOptions options) {
|
||||
functional_tests = [
|
||||
dict(
|
||||
constructor=wrap_functional(F.sample_functional, has_parity=True),
|
||||
cpp_options_args='F::SampleFunctionalFuncOptions(true)',
|
||||
cpp_options_args="F::SampleFunctionalFuncOptions(true)",
|
||||
input_size=(1, 2, 3),
|
||||
fullname='sample_functional_has_parity',
|
||||
fullname="sample_functional_has_parity",
|
||||
has_parity=True,
|
||||
),
|
||||
dict(
|
||||
constructor=wrap_functional(F.sample_functional, has_parity=False),
|
||||
cpp_options_args='F::SampleFunctionalFuncOptions(false)',
|
||||
cpp_options_args="F::SampleFunctionalFuncOptions(false)",
|
||||
input_size=(1, 2, 3),
|
||||
fullname='sample_functional_no_parity',
|
||||
fullname="sample_functional_no_parity",
|
||||
has_parity=False,
|
||||
),
|
||||
# This is to test that setting the `test_cpp_api_parity=False` flag skips
|
||||
@ -61,9 +63,9 @@ functional_tests = [
|
||||
# throw a parity error).
|
||||
dict(
|
||||
constructor=wrap_functional(F.sample_functional, has_parity=False),
|
||||
cpp_options_args='F::SampleFunctionalFuncOptions(false)',
|
||||
cpp_options_args="F::SampleFunctionalFuncOptions(false)",
|
||||
input_size=(1, 2, 3),
|
||||
fullname='sample_functional_THIS_TEST_SHOULD_BE_SKIPPED',
|
||||
fullname="sample_functional_THIS_TEST_SHOULD_BE_SKIPPED",
|
||||
test_cpp_api_parity=False,
|
||||
),
|
||||
]
|
||||
|
@ -1,6 +1,6 @@
|
||||
import torch
|
||||
|
||||
'''
|
||||
"""
|
||||
`SampleModule` is used by `test_cpp_api_parity.py` to test that Python / C++ API
|
||||
parity test harness works for `torch.nn.Module` subclasses.
|
||||
|
||||
@ -9,7 +9,8 @@ is the same as the C++ equivalent.
|
||||
|
||||
When `SampleModule.has_parity` is false, behavior of `forward` / `backward`
|
||||
is different from the C++ equivalent.
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
class SampleModule(torch.nn.Module):
|
||||
def __init__(self, has_parity, has_submodule):
|
||||
@ -19,7 +20,7 @@ class SampleModule(torch.nn.Module):
|
||||
self.submodule = SampleModule(self.has_parity, False)
|
||||
|
||||
self.has_submodule = has_submodule
|
||||
self.register_parameter('param', torch.nn.Parameter(torch.empty(3, 4)))
|
||||
self.register_parameter("param", torch.nn.Parameter(torch.empty(3, 4)))
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
@ -28,12 +29,15 @@ class SampleModule(torch.nn.Module):
|
||||
self.param.fill_(1)
|
||||
|
||||
def forward(self, x):
|
||||
submodule_forward_result = self.submodule(x) if hasattr(self, 'submodule') else 0
|
||||
submodule_forward_result = (
|
||||
self.submodule(x) if hasattr(self, "submodule") else 0
|
||||
)
|
||||
if self.has_parity:
|
||||
return x + self.param * 2 + submodule_forward_result
|
||||
else:
|
||||
return x + self.param * 4 + submodule_forward_result + 3
|
||||
|
||||
|
||||
torch.nn.SampleModule = SampleModule
|
||||
|
||||
SAMPLE_MODULE_CPP_SOURCE = """\n
|
||||
@ -73,31 +77,31 @@ TORCH_MODULE(SampleModule);
|
||||
|
||||
module_tests = [
|
||||
dict(
|
||||
module_name='SampleModule',
|
||||
desc='has_parity',
|
||||
module_name="SampleModule",
|
||||
desc="has_parity",
|
||||
constructor_args=(True, True),
|
||||
cpp_constructor_args='torch::nn::SampleModuleOptions(true, true)',
|
||||
cpp_constructor_args="torch::nn::SampleModuleOptions(true, true)",
|
||||
input_size=(3, 4),
|
||||
cpp_input_args=['torch::randn({3, 4})'],
|
||||
cpp_input_args=["torch::randn({3, 4})"],
|
||||
has_parity=True,
|
||||
),
|
||||
dict(
|
||||
fullname='SampleModule_no_parity',
|
||||
fullname="SampleModule_no_parity",
|
||||
constructor=lambda: SampleModule(has_parity=False, has_submodule=True),
|
||||
cpp_constructor_args='torch::nn::SampleModuleOptions(false, true)',
|
||||
cpp_constructor_args="torch::nn::SampleModuleOptions(false, true)",
|
||||
input_size=(3, 4),
|
||||
cpp_input_args=['torch::randn({3, 4})'],
|
||||
cpp_input_args=["torch::randn({3, 4})"],
|
||||
has_parity=False,
|
||||
),
|
||||
# This is to test that setting the `test_cpp_api_parity=False` flag skips
|
||||
# the C++ API parity test accordingly (otherwise this test would run and
|
||||
# throw a parity error).
|
||||
dict(
|
||||
fullname='SampleModule_THIS_TEST_SHOULD_BE_SKIPPED',
|
||||
fullname="SampleModule_THIS_TEST_SHOULD_BE_SKIPPED",
|
||||
constructor=lambda: SampleModule(False, True),
|
||||
cpp_constructor_args='torch::nn::SampleModuleOptions(false, true)',
|
||||
cpp_constructor_args="torch::nn::SampleModuleOptions(false, true)",
|
||||
input_size=(3, 4),
|
||||
cpp_input_args=['torch::randn({3, 4})'],
|
||||
cpp_input_args=["torch::randn({3, 4})"],
|
||||
test_cpp_api_parity=False,
|
||||
),
|
||||
]
|
||||
|
@ -1,37 +1,33 @@
|
||||
from collections import namedtuple
|
||||
import unittest
|
||||
import os
|
||||
import warnings
|
||||
import shutil
|
||||
import unittest
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
import torch.utils.cpp_extension
|
||||
import torch.testing._internal.common_nn as common_nn
|
||||
import torch.utils.cpp_extension
|
||||
from torch.testing._internal.common_cuda import TEST_CUDA
|
||||
|
||||
# Note that this namedtuple is for C++ parity test mechanism's internal use.
|
||||
# For guidance on how to add a new C++ parity test, please see
|
||||
# NOTE [How to check NN module / functional API parity between Python and C++ frontends]
|
||||
TorchNNModuleTestParams = namedtuple(
|
||||
'TorchNNModuleTestParams',
|
||||
"TorchNNModuleTestParams",
|
||||
[
|
||||
# NN module name (e.g. "BCELoss")
|
||||
'module_name',
|
||||
|
||||
"module_name",
|
||||
# Unique identifier for this module config (e.g. "BCELoss_weights_cuda")
|
||||
'module_variant_name',
|
||||
|
||||
"module_variant_name",
|
||||
# An instance of an NN test class (e.g. `CriterionTest`) which stores
|
||||
# necessary information (e.g. input / target / extra_args) for running the Python test
|
||||
'test_instance',
|
||||
|
||||
"test_instance",
|
||||
# Constructor arguments passed to the C++ module constructor, which must be
|
||||
# strictly equivalent to the Python module constructor arguments
|
||||
# (e.g. `torch::nn::BCELossOptions().weight(torch::rand(10))`,
|
||||
# which is strictly equivalent to passing `torch.rand(10)` to `torch.nn.BCELoss`
|
||||
# constructor in Python)
|
||||
'cpp_constructor_args',
|
||||
|
||||
"cpp_constructor_args",
|
||||
# All arguments used in NN module's forward pass.
|
||||
# Please see `compute_arg_dict` function for details on how we construct this dict.
|
||||
# (e.g.
|
||||
@ -44,42 +40,35 @@ TorchNNModuleTestParams = namedtuple(
|
||||
# }
|
||||
# ```
|
||||
# )
|
||||
'arg_dict',
|
||||
|
||||
"arg_dict",
|
||||
# Whether we expect this NN module test to pass the Python/C++ parity test
|
||||
# (e.g. `True`)
|
||||
'has_parity',
|
||||
|
||||
"has_parity",
|
||||
# Device (e.g. "cuda")
|
||||
'device',
|
||||
|
||||
"device",
|
||||
# Temporary folder to store C++ outputs (to be compared with Python outputs later)
|
||||
'cpp_tmp_folder',
|
||||
]
|
||||
"cpp_tmp_folder",
|
||||
],
|
||||
)
|
||||
|
||||
# Note that this namedtuple is for C++ parity test mechanism's internal use.
|
||||
# For guidance on how to add a new C++ parity test, please see
|
||||
# NOTE [How to check NN module / functional API parity between Python and C++ frontends]
|
||||
TorchNNFunctionalTestParams = namedtuple(
|
||||
'TorchNNFunctionalTestParams',
|
||||
"TorchNNFunctionalTestParams",
|
||||
[
|
||||
# NN functional name (e.g. "binary_cross_entropy")
|
||||
'functional_name',
|
||||
|
||||
"functional_name",
|
||||
# Unique identifier for this functional config (e.g. "BCELoss_no_reduce_cuda")
|
||||
'functional_variant_name',
|
||||
|
||||
"functional_variant_name",
|
||||
# An instance of an NN test class (e.g. `NewModuleTest`) which stores
|
||||
# necessary information (e.g. input / target / extra_args) for running the Python test
|
||||
'test_instance',
|
||||
|
||||
"test_instance",
|
||||
# The C++ function call that is strictly equivalent to the Python function call
|
||||
# (e.g. "F::binary_cross_entropy(
|
||||
# i, t.to(i.options()),F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))",
|
||||
# which is strictly equivalent to `F.binary_cross_entropy(i, t.type_as(i), reduction='none')` in Python)
|
||||
'cpp_function_call',
|
||||
|
||||
"cpp_function_call",
|
||||
# All arguments used in NN functional's function call.
|
||||
# Please see `compute_arg_dict` function for details on how we construct this dict.
|
||||
# (e.g.
|
||||
@ -92,21 +81,18 @@ TorchNNFunctionalTestParams = namedtuple(
|
||||
# }
|
||||
# ```
|
||||
# )
|
||||
'arg_dict',
|
||||
|
||||
"arg_dict",
|
||||
# Whether we expect this NN functional test to pass the Python/C++ parity test
|
||||
# (e.g. `True`)
|
||||
'has_parity',
|
||||
|
||||
"has_parity",
|
||||
# Device (e.g. "cuda")
|
||||
'device',
|
||||
|
||||
"device",
|
||||
# Temporary folder to store C++ outputs (to be compared with Python outputs later)
|
||||
'cpp_tmp_folder',
|
||||
]
|
||||
"cpp_tmp_folder",
|
||||
],
|
||||
)
|
||||
|
||||
CppArg = namedtuple('CppArg', ['name', 'value'])
|
||||
CppArg = namedtuple("CppArg", ["name", "value"])
|
||||
|
||||
TORCH_NN_COMMON_TEST_HARNESS = """
|
||||
#include <torch/script.h>
|
||||
@ -139,21 +125,27 @@ torch::Tensor _rand_tensor_non_equal(torch::IntArrayRef size) {
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def compile_cpp_code_inline(name, cpp_sources, functions):
|
||||
cpp_module = torch.utils.cpp_extension.load_inline(
|
||||
name=name,
|
||||
cpp_sources=cpp_sources,
|
||||
extra_cflags=['-g'], # Enable debug symbols by default for debugging test failures.
|
||||
extra_cflags=[
|
||||
"-g"
|
||||
], # Enable debug symbols by default for debugging test failures.
|
||||
functions=functions,
|
||||
verbose=False,
|
||||
)
|
||||
return cpp_module
|
||||
|
||||
|
||||
def compute_temp_file_path(cpp_tmp_folder, variant_name, file_suffix):
|
||||
return os.path.join(cpp_tmp_folder, f'{variant_name}_{file_suffix}.pt')
|
||||
return os.path.join(cpp_tmp_folder, f"{variant_name}_{file_suffix}.pt")
|
||||
|
||||
|
||||
def is_torch_nn_functional_test(test_params_dict):
|
||||
return 'wrap_functional' in str(test_params_dict.get('constructor', ''))
|
||||
return "wrap_functional" in str(test_params_dict.get("constructor", ""))
|
||||
|
||||
|
||||
def convert_to_list(python_input):
|
||||
if isinstance(python_input, torch.Tensor):
|
||||
@ -161,31 +153,46 @@ def convert_to_list(python_input):
|
||||
else:
|
||||
return list(python_input)
|
||||
|
||||
|
||||
def set_python_tensors_requires_grad(python_tensors):
|
||||
return [tensor.requires_grad_(True) if tensor.dtype != torch.long else tensor for tensor in python_tensors]
|
||||
return [
|
||||
tensor.requires_grad_(True) if tensor.dtype != torch.long else tensor
|
||||
for tensor in python_tensors
|
||||
]
|
||||
|
||||
|
||||
def move_python_tensors_to_device(python_tensors, device):
|
||||
return [tensor.to(device) for tensor in python_tensors]
|
||||
|
||||
|
||||
def has_test(unit_test_class, test_name):
|
||||
return hasattr(unit_test_class, test_name)
|
||||
|
||||
|
||||
def add_test(unit_test_class, test_name, test_fn):
|
||||
if has_test(unit_test_class, test_name):
|
||||
raise RuntimeError("Found two tests with the same name: " + test_name)
|
||||
setattr(unit_test_class, test_name, test_fn)
|
||||
|
||||
|
||||
def set_cpp_tensors_requires_grad(cpp_tensor_stmts, python_tensors):
|
||||
assert len(cpp_tensor_stmts) == len(python_tensors)
|
||||
return [f'{tensor_stmt}.requires_grad_(true)' if tensor.dtype != torch.long else tensor_stmt
|
||||
for tensor_stmt, (_, tensor) in zip(cpp_tensor_stmts, python_tensors)]
|
||||
return [
|
||||
f"{tensor_stmt}.requires_grad_(true)"
|
||||
if tensor.dtype != torch.long
|
||||
else tensor_stmt
|
||||
for tensor_stmt, (_, tensor) in zip(cpp_tensor_stmts, python_tensors)
|
||||
]
|
||||
|
||||
|
||||
def move_cpp_tensors_to_device(cpp_tensor_stmts, device):
|
||||
return [f'{tensor_stmt}.to("{device}")' for tensor_stmt in cpp_tensor_stmts]
|
||||
|
||||
|
||||
def is_criterion_test(test_instance):
|
||||
return isinstance(test_instance, common_nn.CriterionTest)
|
||||
|
||||
|
||||
# This function computes the following:
|
||||
# - What variable declaration statements should show up in the C++ parity test function
|
||||
# - What arguments should be passed into the C++ module/functional's forward function
|
||||
@ -213,26 +220,42 @@ def compute_cpp_args_construction_stmts_and_forward_arg_symbols(test_params):
|
||||
cpp_forward_args_symbols.append(arg_name)
|
||||
return args_stmts
|
||||
|
||||
cpp_forward_input_args_stmts = set_cpp_tensors_requires_grad(move_cpp_tensors_to_device(
|
||||
add_cpp_forward_args(test_params.arg_dict['input']), device), test_params.arg_dict['input'])
|
||||
cpp_forward_input_args_stmts = set_cpp_tensors_requires_grad(
|
||||
move_cpp_tensors_to_device(
|
||||
add_cpp_forward_args(test_params.arg_dict["input"]), device
|
||||
),
|
||||
test_params.arg_dict["input"],
|
||||
)
|
||||
cpp_forward_target_args_stmts = move_cpp_tensors_to_device(
|
||||
add_cpp_forward_args(test_params.arg_dict['target']), device)
|
||||
add_cpp_forward_args(test_params.arg_dict["target"]), device
|
||||
)
|
||||
cpp_forward_extra_args_stmts = move_cpp_tensors_to_device(
|
||||
add_cpp_forward_args(test_params.arg_dict['extra_args']), device)
|
||||
add_cpp_forward_args(test_params.arg_dict["extra_args"]), device
|
||||
)
|
||||
|
||||
# Build the list of other arguments needed
|
||||
cpp_other_args_stmts = []
|
||||
for arg_name, _ in test_params.arg_dict['other']:
|
||||
for arg_name, _ in test_params.arg_dict["other"]:
|
||||
cpp_other_args_stmts.append(f'auto {arg_name} = arg_dict.at("{arg_name}")')
|
||||
cpp_other_args_stmts = move_cpp_tensors_to_device(cpp_other_args_stmts, device)
|
||||
|
||||
cpp_args_construction_stmts = cpp_forward_input_args_stmts + cpp_forward_target_args_stmts + \
|
||||
cpp_forward_extra_args_stmts + cpp_other_args_stmts
|
||||
cpp_args_construction_stmts = (
|
||||
cpp_forward_input_args_stmts
|
||||
+ cpp_forward_target_args_stmts
|
||||
+ cpp_forward_extra_args_stmts
|
||||
+ cpp_other_args_stmts
|
||||
)
|
||||
|
||||
return cpp_args_construction_stmts, cpp_forward_args_symbols
|
||||
|
||||
|
||||
def serialize_arg_dict_as_script_module(arg_dict):
|
||||
arg_dict_flat = dict(arg_dict['input'] + arg_dict['target'] + arg_dict['extra_args'] + arg_dict['other'])
|
||||
arg_dict_flat = dict(
|
||||
arg_dict["input"]
|
||||
+ arg_dict["target"]
|
||||
+ arg_dict["extra_args"]
|
||||
+ arg_dict["other"]
|
||||
)
|
||||
arg_dict_module = torch.nn.Module()
|
||||
for arg_name, arg_value in arg_dict_flat.items():
|
||||
assert isinstance(arg_value, torch.Tensor)
|
||||
@ -240,6 +263,7 @@ def serialize_arg_dict_as_script_module(arg_dict):
|
||||
|
||||
return torch.jit.script(arg_dict_module)
|
||||
|
||||
|
||||
# NOTE: any argument symbol used in `cpp_constructor_args` / `cpp_options_args` / `cpp_function_call`
|
||||
# must have a mapping in `cpp_var_map`.
|
||||
#
|
||||
@ -270,38 +294,47 @@ def serialize_arg_dict_as_script_module(arg_dict):
|
||||
# ```
|
||||
def compute_arg_dict(test_params_dict, test_instance):
|
||||
arg_dict = {
|
||||
'input': [],
|
||||
'target': [],
|
||||
'extra_args': [],
|
||||
'other': [],
|
||||
"input": [],
|
||||
"target": [],
|
||||
"extra_args": [],
|
||||
"other": [],
|
||||
}
|
||||
|
||||
def put_args_into_arg_dict(arg_type, arg_type_prefix, args):
|
||||
for i, arg in enumerate(args):
|
||||
arg_dict[arg_type].append(CppArg(name=arg_type_prefix + str(i), value=arg))
|
||||
|
||||
put_args_into_arg_dict('input', 'i', convert_to_list(test_instance._get_input()))
|
||||
put_args_into_arg_dict("input", "i", convert_to_list(test_instance._get_input()))
|
||||
if is_criterion_test(test_instance):
|
||||
put_args_into_arg_dict('target', 't', convert_to_list(test_instance._get_target()))
|
||||
put_args_into_arg_dict(
|
||||
"target", "t", convert_to_list(test_instance._get_target())
|
||||
)
|
||||
if test_instance.extra_args:
|
||||
put_args_into_arg_dict('extra_args', 'e', convert_to_list(test_instance.extra_args))
|
||||
put_args_into_arg_dict(
|
||||
"extra_args", "e", convert_to_list(test_instance.extra_args)
|
||||
)
|
||||
|
||||
cpp_var_map = test_params_dict.get('cpp_var_map', {})
|
||||
cpp_var_map = test_params_dict.get("cpp_var_map", {})
|
||||
for arg_name, arg_value in cpp_var_map.items():
|
||||
if isinstance(arg_value, str):
|
||||
if arg_value == '_get_input()':
|
||||
arg_dict['other'].append(CppArg(name=arg_name, value=test_instance._get_input()))
|
||||
if arg_value == "_get_input()":
|
||||
arg_dict["other"].append(
|
||||
CppArg(name=arg_name, value=test_instance._get_input())
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"`{arg_name}` has unsupported string value: {arg_value}")
|
||||
raise RuntimeError(
|
||||
f"`{arg_name}` has unsupported string value: {arg_value}"
|
||||
)
|
||||
elif isinstance(arg_value, torch.Tensor):
|
||||
arg_dict['other'].append(CppArg(name=arg_name, value=arg_value))
|
||||
arg_dict["other"].append(CppArg(name=arg_name, value=arg_value))
|
||||
else:
|
||||
raise RuntimeError(f"`{arg_name}` has unsupported value: {arg_value}")
|
||||
|
||||
return arg_dict
|
||||
|
||||
|
||||
def decorate_test_fn(test_fn, test_cuda, has_impl_parity, device):
|
||||
if device == 'cuda':
|
||||
if device == "cuda":
|
||||
test_fn = unittest.skipIf(not TEST_CUDA, "CUDA unavailable")(test_fn)
|
||||
test_fn = unittest.skipIf(not test_cuda, "Excluded from CUDA tests")(test_fn)
|
||||
|
||||
@ -313,7 +346,8 @@ def decorate_test_fn(test_fn, test_cuda, has_impl_parity, device):
|
||||
|
||||
return test_fn
|
||||
|
||||
MESSAGE_HOW_TO_FIX_CPP_PARITY_TEST_FAILURE = '''
|
||||
|
||||
MESSAGE_HOW_TO_FIX_CPP_PARITY_TEST_FAILURE = """
|
||||
What should I do when C++ API parity test is failing?
|
||||
|
||||
- If you are changing the implementation of an existing `torch.nn` module / `torch.nn.functional` function:
|
||||
@ -337,13 +371,15 @@ also file an issue when you do this.
|
||||
|
||||
For more details on how to add a C++ API parity test, please see:
|
||||
NOTE [How to check NN module / functional API parity between Python and C++ frontends]
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
def generate_error_msg(name, cpp_value, python_value):
|
||||
return (
|
||||
"Parity test failed: {} in C++ has value: {}, "
|
||||
"which does not match the corresponding value in Python: {}.\n{}").format(
|
||||
name, cpp_value, python_value, MESSAGE_HOW_TO_FIX_CPP_PARITY_TEST_FAILURE)
|
||||
"which does not match the corresponding value in Python: {}.\n{}"
|
||||
).format(name, cpp_value, python_value, MESSAGE_HOW_TO_FIX_CPP_PARITY_TEST_FAILURE)
|
||||
|
||||
|
||||
def try_remove_folder(folder_path):
|
||||
if os.path.exists(folder_path):
|
||||
@ -351,4 +387,6 @@ def try_remove_folder(folder_path):
|
||||
try:
|
||||
shutil.rmtree(folder_path)
|
||||
except Exception as e:
|
||||
warnings.warn(f"Non-blocking folder removal fails with the following error:\n{str(e)}")
|
||||
warnings.warn(
|
||||
f"Non-blocking folder removal fails with the following error:\n{str(e)}"
|
||||
)
|
||||
|
Reference in New Issue
Block a user