Enable UFMT on test/cpp_api_parity, test/cpp_extensions, test/create_dummy_torchscript_model.py, test/custom_backend, test/custom_operator (#123518)

Partially addresses #123062

Ran lintrunner on:

- `test/cpp_api_parity`
- `test/cpp_extensions`
- `test/create_dummy_torchscript_model.py`
- `test/custom_backend`
- `test/custom_operator`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123518
Approved by: https://github.com/huydhn
This commit is contained in:
Yuanhao Ji
2024-04-08 20:18:42 +00:00
committed by PyTorch MergeBot
parent b279034e5a
commit c797fbc4e1
12 changed files with 529 additions and 312 deletions

View File

@ -14,26 +14,39 @@
# 4. Compare Python/C++ functional's forward output. If they are the same, then we
# have implementation parity between Python/C++ module.
import os
import pprint
import re
import tempfile
from string import Template
import re
import pprint
import os
import torch
from cpp_api_parity.utils import TorchNNFunctionalTestParams, TORCH_NN_COMMON_TEST_HARNESS, \
compile_cpp_code_inline, set_python_tensors_requires_grad, move_python_tensors_to_device, \
add_test, compute_cpp_args_construction_stmts_and_forward_arg_symbols, serialize_arg_dict_as_script_module, \
compute_arg_dict, decorate_test_fn, compute_temp_file_path, generate_error_msg, is_torch_nn_functional_test, \
try_remove_folder
from cpp_api_parity.sample_functional import SAMPLE_FUNCTIONAL_CPP_SOURCE
from cpp_api_parity.utils import (
add_test,
compile_cpp_code_inline,
compute_arg_dict,
compute_cpp_args_construction_stmts_and_forward_arg_symbols,
compute_temp_file_path,
decorate_test_fn,
generate_error_msg,
is_torch_nn_functional_test,
move_python_tensors_to_device,
serialize_arg_dict_as_script_module,
set_python_tensors_requires_grad,
TORCH_NN_COMMON_TEST_HARNESS,
TorchNNFunctionalTestParams,
try_remove_folder,
)
# Expected substitutions:
#
# ${functional_variant_name} (e.g. `BCELoss_no_reduce`)
# ${cpp_args_construction_stmts}
# ${cpp_function_call}
TORCH_NN_FUNCTIONAL_TEST_FORWARD = Template("""
TORCH_NN_FUNCTIONAL_TEST_FORWARD = Template(
"""
void ${functional_variant_name}_test_forward(
const std::string& arg_dict_file_path,
const std::string& forward_output_file_path) {
@ -56,17 +69,24 @@ void ${functional_variant_name}_test_forward(
// Save the output into a file to be compared in Python later
write_ivalue_to_file(torch::IValue(cpp_output), forward_output_file_path);
}
""")
"""
)
def run_forward(unit_test_class, test_params):
device = test_params.device
inputs = set_python_tensors_requires_grad(move_python_tensors_to_device(
[arg_value for _, arg_value in test_params.arg_dict['input']], device))
inputs = set_python_tensors_requires_grad(
move_python_tensors_to_device(
[arg_value for _, arg_value in test_params.arg_dict["input"]], device
)
)
inputs += move_python_tensors_to_device(
[arg_value for _, arg_value in test_params.arg_dict['target']], device)
[arg_value for _, arg_value in test_params.arg_dict["target"]], device
)
inputs += move_python_tensors_to_device(
[arg_value for _, arg_value in test_params.arg_dict['extra_args']], device)
[arg_value for _, arg_value in test_params.arg_dict["extra_args"]], device
)
# Some functionals (such as `F.rrelu`) create random tensors in their call path.
# To make sure the random tensors created are the same in Python/C++, we need
@ -76,6 +96,7 @@ def run_forward(unit_test_class, test_params):
return python_output
def test_forward(unit_test_class, test_params):
functional_variant_name = test_params.functional_variant_name
cpp_tmp_folder = test_params.cpp_tmp_folder
@ -87,105 +108,149 @@ def test_forward(unit_test_class, test_params):
python_output = run_forward(unit_test_class, test_params)
# Save Python arguments to be used from C++ function
arg_dict_file_path = compute_temp_file_path(cpp_tmp_folder, functional_variant_name, 'arg_dict')
arg_dict_file_path = compute_temp_file_path(
cpp_tmp_folder, functional_variant_name, "arg_dict"
)
serialize_arg_dict_as_script_module(test_params.arg_dict).save(arg_dict_file_path)
cpp_test_name = f'{test_params.functional_variant_name}_test_forward'
cpp_test_fn = getattr(unit_test_class.functional_impl_check_cpp_module, cpp_test_name)
cpp_test_name = f"{test_params.functional_variant_name}_test_forward"
cpp_test_fn = getattr(
unit_test_class.functional_impl_check_cpp_module, cpp_test_name
)
def run_cpp_test_fn_and_check_output():
forward_output_file_path = compute_temp_file_path(cpp_tmp_folder, functional_variant_name, 'forward_output')
forward_output_file_path = compute_temp_file_path(
cpp_tmp_folder, functional_variant_name, "forward_output"
)
cpp_test_fn(arg_dict_file_path, forward_output_file_path)
cpp_output = torch.load(forward_output_file_path)
# Check that forward outputs are equal
unit_test_class.assertEqual(
python_output, cpp_output,
msg=generate_error_msg("forward output", cpp_output, python_output))
python_output,
cpp_output,
msg=generate_error_msg("forward output", cpp_output, python_output),
)
run_cpp_test_fn_and_check_output()
# Remove temporary folder that stores C++ outputs
try_remove_folder(cpp_tmp_folder)
def compute_functional_name(test_params_dict):
def camel_case_to_snake_case(camel_case_str):
return re.sub(r'(?<!^)(?=[A-Z])', '_', camel_case_str).lower()
return re.sub(r"(?<!^)(?=[A-Z])", "_", camel_case_str).lower()
if 'cpp_options_args' in test_params_dict:
if "cpp_options_args" in test_params_dict:
# Expected format for `cpp_options_args`: `F::FunctionalFuncOptions(...)`
# Example output: `binary_cross_entropy`
return camel_case_to_snake_case(
test_params_dict['cpp_options_args'].split('(')[0].replace('F::', '').replace('FuncOptions', ''))
elif 'cpp_function_call' in test_params_dict:
test_params_dict["cpp_options_args"]
.split("(")[0]
.replace("F::", "")
.replace("FuncOptions", "")
)
elif "cpp_function_call" in test_params_dict:
# Expected format for `cpp_function_call`: `F::functional_name(...)`
# Example output: `binary_cross_entropy`
return test_params_dict['cpp_function_call'].split('(')[0].replace('F::', '')
return test_params_dict["cpp_function_call"].split("(")[0].replace("F::", "")
else:
raise RuntimeError(
"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{}".format(
pprint.pformat(test_params_dict)))
pprint.pformat(test_params_dict)
)
)
def compute_cpp_function_call(test_params_dict, arg_dict, functional_name):
if 'cpp_function_call' in test_params_dict:
return test_params_dict['cpp_function_call']
elif 'cpp_options_args' in test_params_dict:
cpp_forward_args_symbols = [arg_name for arg_name, _ in
arg_dict['input'] + arg_dict['target'] + arg_dict['extra_args']]
return 'F::{}({}, {})'.format(
functional_name, ", ".join(cpp_forward_args_symbols), test_params_dict['cpp_options_args'])
if "cpp_function_call" in test_params_dict:
return test_params_dict["cpp_function_call"]
elif "cpp_options_args" in test_params_dict:
cpp_forward_args_symbols = [
arg_name
for arg_name, _ in arg_dict["input"]
+ arg_dict["target"]
+ arg_dict["extra_args"]
]
return "F::{}({}, {})".format(
functional_name,
", ".join(cpp_forward_args_symbols),
test_params_dict["cpp_options_args"],
)
else:
raise RuntimeError(
"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{}".format(
pprint.pformat(test_params_dict)))
pprint.pformat(test_params_dict)
)
)
def process_test_params_for_functional(test_params_dict, device, test_instance_class):
test_instance = test_instance_class(**test_params_dict)
functional_name = compute_functional_name(test_params_dict)
assert test_instance.get_name().startswith('test_')
assert test_instance.get_name().startswith("test_")
# Example output: `BCELoss_no_reduce_cuda`
functional_variant_name = test_instance.get_name()[5:] + (('_' + device) if device != 'cpu' else '')
functional_variant_name = test_instance.get_name()[5:] + (
("_" + device) if device != "cpu" else ""
)
arg_dict = compute_arg_dict(test_params_dict, test_instance)
return TorchNNFunctionalTestParams(
functional_name=functional_name,
functional_variant_name=functional_variant_name,
test_instance=test_instance,
cpp_function_call=compute_cpp_function_call(test_params_dict, arg_dict, functional_name),
cpp_function_call=compute_cpp_function_call(
test_params_dict, arg_dict, functional_name
),
arg_dict=arg_dict,
has_parity=test_params_dict.get('has_parity', True),
has_parity=test_params_dict.get("has_parity", True),
device=device,
cpp_tmp_folder=tempfile.mkdtemp(),
)
def write_test_to_test_class(
unit_test_class, test_params_dict, test_instance_class, parity_table, devices):
unit_test_class, test_params_dict, test_instance_class, parity_table, devices
):
assert is_torch_nn_functional_test(test_params_dict)
assert 'cpp_options_args' in test_params_dict or 'cpp_function_call' in test_params_dict, (
assert (
"cpp_options_args" in test_params_dict
or "cpp_function_call" in test_params_dict
), (
"To enable C++ API parity test, "
"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{}. \n"
"If you are interested in adding the C++ API parity test, please see:\n"
"NOTE [How to check NN module / functional API parity between Python and C++ frontends]. \n"
"If not, please add `test_cpp_api_parity=False` to the test params dict and file an issue about this."
).format(pprint.pformat(test_params_dict))
).format(
pprint.pformat(test_params_dict)
)
assert not ('cpp_options_args' in test_params_dict and 'cpp_function_call' in test_params_dict), (
assert not (
"cpp_options_args" in test_params_dict
and "cpp_function_call" in test_params_dict
), (
"Only one of `cpp_options_args` and `cpp_function_call` entries "
f"should be present in test params dict:\n{pprint.pformat(test_params_dict)}")
f"should be present in test params dict:\n{pprint.pformat(test_params_dict)}"
)
functional_name = compute_functional_name(test_params_dict)
assert hasattr(torch.nn.functional, functional_name), \
"`torch.nn.functional` doesn't have function `{}`. (Discovered while processing\n{}.)".format(
functional_name, pprint.pformat(test_params_dict))
assert hasattr(
torch.nn.functional, functional_name
), "`torch.nn.functional` doesn't have function `{}`. (Discovered while processing\n{}.)".format(
functional_name, pprint.pformat(test_params_dict)
)
functional_full_name = 'F::' + functional_name
functional_full_name = "F::" + functional_name
assert functional_full_name in parity_table['torch::nn::functional'], (
assert functional_full_name in parity_table["torch::nn::functional"], (
"Please add `{}` entry to `torch::nn::functional` section of `test/cpp_api_parity/parity-tracker.md`. "
"(Discovered while processing\n{}.)").format(functional_full_name, pprint.pformat(test_params_dict))
"(Discovered while processing\n{}.)"
).format(functional_full_name, pprint.pformat(test_params_dict))
for device in devices:
test_params = process_test_params_for_functional(
@ -194,24 +259,37 @@ def write_test_to_test_class(
test_instance_class=test_instance_class,
)
try_remove_folder(test_params.cpp_tmp_folder)
unit_test_name = f'test_torch_nn_functional_{test_params.functional_variant_name}'
unit_test_name = (
f"test_torch_nn_functional_{test_params.functional_variant_name}"
)
unit_test_class.functional_test_params_map[unit_test_name] = test_params
def test_fn(self):
test_forward(
unit_test_class=self, test_params=unit_test_class.functional_test_params_map[self._testMethodName])
unit_test_class=self,
test_params=unit_test_class.functional_test_params_map[
self._testMethodName
],
)
test_fn = decorate_test_fn(
test_fn=test_fn,
test_cuda=test_params_dict.get('test_cuda', True),
has_impl_parity=parity_table['torch::nn::functional'][functional_full_name][0] and
test_params_dict.get('has_parity', True),
device=device)
test_cuda=test_params_dict.get("test_cuda", True),
has_impl_parity=parity_table["torch::nn::functional"][functional_full_name][
0
]
and test_params_dict.get("has_parity", True),
device=device,
)
add_test(unit_test_class, unit_test_name, test_fn)
def generate_test_cpp_sources(test_params, template):
cpp_args_construction_stmts, _ = compute_cpp_args_construction_stmts_and_forward_arg_symbols(test_params)
(
cpp_args_construction_stmts,
_,
) = compute_cpp_args_construction_stmts_and_forward_arg_symbols(test_params)
test_cpp_sources = template.substitute(
functional_variant_name=test_params.functional_variant_name,
@ -220,19 +298,21 @@ def generate_test_cpp_sources(test_params, template):
)
return test_cpp_sources
# Build all C++ tests together, instead of once per test.
def build_cpp_tests(unit_test_class, print_cpp_source=False):
assert len(unit_test_class.functional_test_params_map) > 0
cpp_sources = TORCH_NN_COMMON_TEST_HARNESS + SAMPLE_FUNCTIONAL_CPP_SOURCE
functions = []
for test_params in unit_test_class.functional_test_params_map.values():
cpp_sources += generate_test_cpp_sources(test_params=test_params, template=TORCH_NN_FUNCTIONAL_TEST_FORWARD)
functions.append(f'{test_params.functional_variant_name}_test_forward')
cpp_sources += generate_test_cpp_sources(
test_params=test_params, template=TORCH_NN_FUNCTIONAL_TEST_FORWARD
)
functions.append(f"{test_params.functional_variant_name}_test_forward")
if print_cpp_source:
print(cpp_sources)
cpp_module = compile_cpp_code_inline(
name='functional_impl_check',
cpp_sources=cpp_sources,
functions=functions)
name="functional_impl_check", cpp_sources=cpp_sources, functions=functions
)
unit_test_class.functional_impl_check_cpp_module = cpp_module

View File

@ -14,19 +14,31 @@
# 4. Compare Python/C++ module's forward output and backward gradients. If they
# are the same, then we have implementation parity between Python/C++ module.
import tempfile
from string import Template
import types
import pprint
import os
import pprint
import tempfile
import types
from string import Template
import torch
from cpp_api_parity.utils import TorchNNModuleTestParams, TORCH_NN_COMMON_TEST_HARNESS, \
compile_cpp_code_inline, set_python_tensors_requires_grad, move_python_tensors_to_device, \
add_test, compute_cpp_args_construction_stmts_and_forward_arg_symbols, serialize_arg_dict_as_script_module, \
compute_arg_dict, decorate_test_fn, compute_temp_file_path, generate_error_msg, is_torch_nn_functional_test, \
try_remove_folder
from cpp_api_parity.sample_module import SAMPLE_MODULE_CPP_SOURCE
from cpp_api_parity.utils import (
add_test,
compile_cpp_code_inline,
compute_arg_dict,
compute_cpp_args_construction_stmts_and_forward_arg_symbols,
compute_temp_file_path,
decorate_test_fn,
generate_error_msg,
is_torch_nn_functional_test,
move_python_tensors_to_device,
serialize_arg_dict_as_script_module,
set_python_tensors_requires_grad,
TORCH_NN_COMMON_TEST_HARNESS,
TorchNNModuleTestParams,
try_remove_folder,
)
# Expected substitutions:
#
@ -36,7 +48,8 @@ from cpp_api_parity.sample_module import SAMPLE_MODULE_CPP_SOURCE
# ${cpp_constructor_args}
# ${device}
# ${cpp_forward_args_symbols}
TORCH_NN_MODULE_TEST_FORWARD_BACKWARD = Template("""
TORCH_NN_MODULE_TEST_FORWARD_BACKWARD = Template(
"""
void ${module_variant_name}_test_forward_backward(
const std::string& arg_dict_file_path,
const std::string& module_file_path,
@ -85,18 +98,27 @@ void ${module_variant_name}_test_forward_backward(
write_ivalue_to_file(torch::IValue(grad_dict), backward_grad_dict_file_path);
}
""")
"""
)
def run_python_forward_backward(unit_test_class, test_params):
device = test_params.device
module = test_params.test_instance.constructor(*test_params.test_instance.constructor_args).to(device)
module = test_params.test_instance.constructor(
*test_params.test_instance.constructor_args
).to(device)
inputs = set_python_tensors_requires_grad(move_python_tensors_to_device(
[arg_value for _, arg_value in test_params.arg_dict['input']], device))
inputs = set_python_tensors_requires_grad(
move_python_tensors_to_device(
[arg_value for _, arg_value in test_params.arg_dict["input"]], device
)
)
inputs += move_python_tensors_to_device(
[arg_value for _, arg_value in test_params.arg_dict['target']], device)
[arg_value for _, arg_value in test_params.arg_dict["target"]], device
)
inputs += move_python_tensors_to_device(
[arg_value for _, arg_value in test_params.arg_dict['extra_args']], device)
[arg_value for _, arg_value in test_params.arg_dict["extra_args"]], device
)
# Some modules (such as `RReLU`) create random tensors in their forward pass.
# To make sure the random tensors created are the same in Python/C++, we need
@ -130,6 +152,7 @@ def run_python_forward_backward(unit_test_class, test_params):
return script_module, python_output, python_grad_dict
def test_forward_backward(unit_test_class, test_params):
module_variant_name = test_params.module_variant_name
cpp_tmp_folder = test_params.cpp_tmp_folder
@ -138,76 +161,113 @@ def test_forward_backward(unit_test_class, test_params):
os.mkdir(cpp_tmp_folder)
# Run forward and backward on Python module
script_module, python_output, python_grad_dict = run_python_forward_backward(unit_test_class, test_params)
script_module, python_output, python_grad_dict = run_python_forward_backward(
unit_test_class, test_params
)
# Save Python module and arguments to be used from C++ function
module_file_path = compute_temp_file_path(cpp_tmp_folder, module_variant_name, 'module')
arg_dict_file_path = compute_temp_file_path(cpp_tmp_folder, module_variant_name, 'arg_dict')
module_file_path = compute_temp_file_path(
cpp_tmp_folder, module_variant_name, "module"
)
arg_dict_file_path = compute_temp_file_path(
cpp_tmp_folder, module_variant_name, "arg_dict"
)
script_module.save(module_file_path)
serialize_arg_dict_as_script_module(test_params.arg_dict).save(arg_dict_file_path)
cpp_test_name = f'{test_params.module_variant_name}_test_forward_backward'
cpp_test_name = f"{test_params.module_variant_name}_test_forward_backward"
cpp_test_fn = getattr(unit_test_class.module_impl_check_cpp_module, cpp_test_name)
def run_cpp_test_fn_and_check_output():
forward_output_file_path = compute_temp_file_path(cpp_tmp_folder, module_variant_name, 'forward_output')
backward_grad_dict_file_path = compute_temp_file_path(cpp_tmp_folder, module_variant_name, 'backward_grad_dict')
forward_output_file_path = compute_temp_file_path(
cpp_tmp_folder, module_variant_name, "forward_output"
)
backward_grad_dict_file_path = compute_temp_file_path(
cpp_tmp_folder, module_variant_name, "backward_grad_dict"
)
cpp_test_fn(arg_dict_file_path, module_file_path, forward_output_file_path, backward_grad_dict_file_path)
cpp_test_fn(
arg_dict_file_path,
module_file_path,
forward_output_file_path,
backward_grad_dict_file_path,
)
cpp_output = torch.load(forward_output_file_path)
cpp_grad_dict = torch.load(backward_grad_dict_file_path)
# Check that forward outputs are equal
unit_test_class.assertEqual(python_output, cpp_output,
msg=generate_error_msg("forward output", cpp_output, python_output))
unit_test_class.assertEqual(
python_output,
cpp_output,
msg=generate_error_msg("forward output", cpp_output, python_output),
)
# Check that module parameter gradients are equal after backward pass
unit_test_class.assertEqual(
len(python_grad_dict), len(cpp_grad_dict),
msg=generate_error_msg("# of parameters", len(cpp_grad_dict), len(python_grad_dict)))
len(python_grad_dict),
len(cpp_grad_dict),
msg=generate_error_msg(
"# of parameters", len(cpp_grad_dict), len(python_grad_dict)
),
)
for key in python_grad_dict:
param_name = None
for suffix in ['_grad', '_grad_indices', '_grad_values']:
for suffix in ["_grad", "_grad_indices", "_grad_values"]:
if key.endswith(suffix):
param_name = key[:-len(suffix)]
param_name = key[: -len(suffix)]
break
assert param_name is not None
sparsity_str = 'sparse' if key.endswith(('_grad_indices', '_grad_values')) else 'dense'
sparsity_str = (
"sparse" if key.endswith(("_grad_indices", "_grad_values")) else "dense"
)
unit_test_class.assertTrue(
key in cpp_grad_dict,
msg=generate_error_msg(
f"\"Does module have a parameter named `{param_name}` with {sparsity_str} gradient?\"",
False, True))
f'"Does module have a parameter named `{param_name}` with {sparsity_str} gradient?"',
False,
True,
),
)
unit_test_class.assertEqual(
python_grad_dict[key], cpp_grad_dict[key],
python_grad_dict[key],
cpp_grad_dict[key],
msg=generate_error_msg(
f"`{param_name}`'s {sparsity_str} gradient (`{key}`)",
cpp_grad_dict[key], python_grad_dict[key]))
cpp_grad_dict[key],
python_grad_dict[key],
),
)
run_cpp_test_fn_and_check_output()
# Remove temporary folder that stores C++ outputs
try_remove_folder(cpp_tmp_folder)
def compute_module_name(test_params_dict):
fullname = test_params_dict.get('fullname', None)
fullname = test_params_dict.get("fullname", None)
if fullname:
module_name = fullname.split('_')[0]
module_name = fullname.split("_")[0]
else:
module_name = test_params_dict.get('module_name')
module_name = test_params_dict.get("module_name")
return module_name
def process_test_params_for_module(test_params_dict, device, test_instance_class):
module_name = compute_module_name(test_params_dict)
test_params_dict['constructor'] = test_params_dict.get('constructor', getattr(torch.nn, module_name))
test_params_dict["constructor"] = test_params_dict.get(
"constructor", getattr(torch.nn, module_name)
)
test_instance = test_instance_class(**test_params_dict)
assert test_instance.get_name().startswith('test_')
assert test_instance.get_name().startswith("test_")
# Example output: `BCELoss_weights_cuda`
module_variant_name = test_instance.get_name()[5:] + (('_' + device) if device != 'cpu' else '')
module_variant_name = test_instance.get_name()[5:] + (
("_" + device) if device != "cpu" else ""
)
if 'constructor_args' in test_params_dict:
assert 'cpp_constructor_args' in test_params_dict, (
if "constructor_args" in test_params_dict:
assert "cpp_constructor_args" in test_params_dict, (
"If `constructor_args` is present in test params dict, to enable C++ API parity test, "
f"`cpp_constructor_args` must be present in:\n{pprint.pformat(test_params_dict)}"
"If you are interested in adding the C++ API parity test, please see:\n"
@ -219,15 +279,17 @@ def process_test_params_for_module(test_params_dict, device, test_instance_class
module_name=module_name,
module_variant_name=module_variant_name,
test_instance=test_instance,
cpp_constructor_args=test_params_dict.get('cpp_constructor_args', ''),
cpp_constructor_args=test_params_dict.get("cpp_constructor_args", ""),
arg_dict=compute_arg_dict(test_params_dict, test_instance),
has_parity=test_params_dict.get('has_parity', True),
has_parity=test_params_dict.get("has_parity", True),
device=device,
cpp_tmp_folder=tempfile.mkdtemp(),
)
def write_test_to_test_class(
unit_test_class, test_params_dict, test_instance_class, parity_table, devices):
unit_test_class, test_params_dict, test_instance_class, parity_table, devices
):
assert not is_torch_nn_functional_test(test_params_dict)
module_name = compute_module_name(test_params_dict)
@ -238,11 +300,12 @@ def write_test_to_test_class(
f"or set `module_name` using format `ModuleName` in the module test dict:\n{pprint.pformat(test_params_dict)}"
)
module_full_name = 'torch::nn::' + module_name
module_full_name = "torch::nn::" + module_name
assert module_full_name in parity_table['torch::nn'], (
assert module_full_name in parity_table["torch::nn"], (
f"Please add `{module_full_name}` entry to `torch::nn` section of `test/cpp_api_parity/parity-tracker.md`. "
f"(Discovered while processing\n{pprint.pformat(test_params_dict)}.)")
f"(Discovered while processing\n{pprint.pformat(test_params_dict)}.)"
)
for device in devices:
test_params = process_test_params_for_module(
@ -251,35 +314,43 @@ def write_test_to_test_class(
test_instance_class=test_instance_class,
)
try_remove_folder(test_params.cpp_tmp_folder)
unit_test_name = f'test_torch_nn_{test_params.module_variant_name}'
unit_test_name = f"test_torch_nn_{test_params.module_variant_name}"
unit_test_class.module_test_params_map[unit_test_name] = test_params
def test_fn(self):
test_forward_backward(
unit_test_class=self, test_params=unit_test_class.module_test_params_map[self._testMethodName])
unit_test_class=self,
test_params=unit_test_class.module_test_params_map[
self._testMethodName
],
)
test_fn = decorate_test_fn(
test_fn=test_fn,
test_cuda=test_params_dict.get('test_cuda', True),
has_impl_parity=parity_table['torch::nn'][module_full_name][0] and
test_params_dict.get('has_parity', True),
device=device)
test_cuda=test_params_dict.get("test_cuda", True),
has_impl_parity=parity_table["torch::nn"][module_full_name][0]
and test_params_dict.get("has_parity", True),
device=device,
)
add_test(unit_test_class, unit_test_name, test_fn)
def generate_test_cpp_sources(test_params, template):
device = test_params.device
cpp_constructor_args = test_params.cpp_constructor_args
if cpp_constructor_args != '':
cpp_constructor_args = f'({cpp_constructor_args})'
if cpp_constructor_args != "":
cpp_constructor_args = f"({cpp_constructor_args})"
cpp_args_construction_stmts, cpp_forward_args_symbols = \
compute_cpp_args_construction_stmts_and_forward_arg_symbols(test_params)
(
cpp_args_construction_stmts,
cpp_forward_args_symbols,
) = compute_cpp_args_construction_stmts_and_forward_arg_symbols(test_params)
test_cpp_sources = template.substitute(
module_variant_name=test_params.module_variant_name,
module_qualified_name=f'torch::nn::{test_params.module_name}',
module_qualified_name=f"torch::nn::{test_params.module_name}",
cpp_args_construction_stmts=";\n ".join(cpp_args_construction_stmts),
cpp_constructor_args=cpp_constructor_args,
cpp_forward_args_symbols=", ".join(cpp_forward_args_symbols),
@ -287,6 +358,7 @@ def generate_test_cpp_sources(test_params, template):
)
return test_cpp_sources
# Build all C++ tests together, instead of once per test.
def build_cpp_tests(unit_test_class, print_cpp_source=False):
assert len(unit_test_class.module_test_params_map) > 0
@ -294,13 +366,13 @@ def build_cpp_tests(unit_test_class, print_cpp_source=False):
functions = []
for test_params in unit_test_class.module_test_params_map.values():
cpp_sources += generate_test_cpp_sources(
test_params=test_params, template=TORCH_NN_MODULE_TEST_FORWARD_BACKWARD)
functions.append(f'{test_params.module_variant_name}_test_forward_backward')
test_params=test_params, template=TORCH_NN_MODULE_TEST_FORWARD_BACKWARD
)
functions.append(f"{test_params.module_variant_name}_test_forward_backward")
if print_cpp_source:
print(cpp_sources)
cpp_module = compile_cpp_code_inline(
name='module_impl_check',
cpp_sources=cpp_sources,
functions=functions)
name="module_impl_check", cpp_sources=cpp_sources, functions=functions
)
unit_test_class.module_impl_check_cpp_module = cpp_module

View File

@ -1,8 +1,8 @@
from collections import namedtuple
ParityStatus = namedtuple('ParityStatus', ['has_impl_parity', 'has_doc_parity'])
ParityStatus = namedtuple("ParityStatus", ["has_impl_parity", "has_doc_parity"])
'''
"""
This function expects the parity tracker Markdown file to have the following format:
```
@ -28,31 +28,39 @@ Dict[package_name]
-> Dict[api_name]
-> ParityStatus
```
'''
"""
def parse_parity_tracker_table(file_path):
def parse_parity_choice(str):
if str in ['Yes', 'No']:
return str == 'Yes'
if str in ["Yes", "No"]:
return str == "Yes"
else:
raise RuntimeError(
f'{str} is not a supported parity choice. The valid choices are "Yes" and "No".')
f'{str} is not a supported parity choice. The valid choices are "Yes" and "No".'
)
parity_tracker_dict = {}
with open(file_path) as f:
all_text = f.read()
packages = all_text.split('##')
packages = all_text.split("##")
for package in packages[1:]:
lines = [line.strip() for line in package.split('\n') if line.strip() != '']
lines = [line.strip() for line in package.split("\n") if line.strip() != ""]
package_name = lines[0]
if package_name in parity_tracker_dict:
raise RuntimeError(f"Duplicated package name `{package_name}` found in {file_path}")
raise RuntimeError(
f"Duplicated package name `{package_name}` found in {file_path}"
)
else:
parity_tracker_dict[package_name] = {}
for api_status in lines[3:]:
api_name, has_impl_parity_str, has_doc_parity_str = (x.strip() for x in api_status.split('|'))
api_name, has_impl_parity_str, has_doc_parity_str = (
x.strip() for x in api_status.split("|")
)
parity_tracker_dict[package_name][api_name] = ParityStatus(
has_impl_parity=parse_parity_choice(has_impl_parity_str),
has_doc_parity=parse_parity_choice(has_doc_parity_str))
has_doc_parity=parse_parity_choice(has_doc_parity_str),
)
return parity_tracker_dict

View File

@ -2,7 +2,7 @@ import torch
import torch.nn.functional as F
from torch.testing._internal.common_nn import wrap_functional
'''
"""
`sample_functional` is used by `test_cpp_api_parity.py` to test that Python / C++ API
parity test harness works for `torch.nn.functional` functions.
@ -11,7 +11,8 @@ is the same as the C++ equivalent.
When `has_parity=false` is passed to `sample_functional`, behavior of `sample_functional`
is different from the C++ equivalent.
'''
"""
def sample_functional(x, has_parity):
if has_parity:
@ -19,6 +20,7 @@ def sample_functional(x, has_parity):
else:
return x * 4
torch.nn.functional.sample_functional = sample_functional
SAMPLE_FUNCTIONAL_CPP_SOURCE = """\n
@ -44,16 +46,16 @@ Tensor sample_functional(Tensor x, SampleFunctionalFuncOptions options) {
functional_tests = [
dict(
constructor=wrap_functional(F.sample_functional, has_parity=True),
cpp_options_args='F::SampleFunctionalFuncOptions(true)',
cpp_options_args="F::SampleFunctionalFuncOptions(true)",
input_size=(1, 2, 3),
fullname='sample_functional_has_parity',
fullname="sample_functional_has_parity",
has_parity=True,
),
dict(
constructor=wrap_functional(F.sample_functional, has_parity=False),
cpp_options_args='F::SampleFunctionalFuncOptions(false)',
cpp_options_args="F::SampleFunctionalFuncOptions(false)",
input_size=(1, 2, 3),
fullname='sample_functional_no_parity',
fullname="sample_functional_no_parity",
has_parity=False,
),
# This is to test that setting the `test_cpp_api_parity=False` flag skips
@ -61,9 +63,9 @@ functional_tests = [
# throw a parity error).
dict(
constructor=wrap_functional(F.sample_functional, has_parity=False),
cpp_options_args='F::SampleFunctionalFuncOptions(false)',
cpp_options_args="F::SampleFunctionalFuncOptions(false)",
input_size=(1, 2, 3),
fullname='sample_functional_THIS_TEST_SHOULD_BE_SKIPPED',
fullname="sample_functional_THIS_TEST_SHOULD_BE_SKIPPED",
test_cpp_api_parity=False,
),
]

View File

@ -1,6 +1,6 @@
import torch
'''
"""
`SampleModule` is used by `test_cpp_api_parity.py` to test that Python / C++ API
parity test harness works for `torch.nn.Module` subclasses.
@ -9,7 +9,8 @@ is the same as the C++ equivalent.
When `SampleModule.has_parity` is false, behavior of `forward` / `backward`
is different from the C++ equivalent.
'''
"""
class SampleModule(torch.nn.Module):
def __init__(self, has_parity, has_submodule):
@ -19,7 +20,7 @@ class SampleModule(torch.nn.Module):
self.submodule = SampleModule(self.has_parity, False)
self.has_submodule = has_submodule
self.register_parameter('param', torch.nn.Parameter(torch.empty(3, 4)))
self.register_parameter("param", torch.nn.Parameter(torch.empty(3, 4)))
self.reset_parameters()
@ -28,12 +29,15 @@ class SampleModule(torch.nn.Module):
self.param.fill_(1)
def forward(self, x):
submodule_forward_result = self.submodule(x) if hasattr(self, 'submodule') else 0
submodule_forward_result = (
self.submodule(x) if hasattr(self, "submodule") else 0
)
if self.has_parity:
return x + self.param * 2 + submodule_forward_result
else:
return x + self.param * 4 + submodule_forward_result + 3
torch.nn.SampleModule = SampleModule
SAMPLE_MODULE_CPP_SOURCE = """\n
@ -73,31 +77,31 @@ TORCH_MODULE(SampleModule);
module_tests = [
dict(
module_name='SampleModule',
desc='has_parity',
module_name="SampleModule",
desc="has_parity",
constructor_args=(True, True),
cpp_constructor_args='torch::nn::SampleModuleOptions(true, true)',
cpp_constructor_args="torch::nn::SampleModuleOptions(true, true)",
input_size=(3, 4),
cpp_input_args=['torch::randn({3, 4})'],
cpp_input_args=["torch::randn({3, 4})"],
has_parity=True,
),
dict(
fullname='SampleModule_no_parity',
fullname="SampleModule_no_parity",
constructor=lambda: SampleModule(has_parity=False, has_submodule=True),
cpp_constructor_args='torch::nn::SampleModuleOptions(false, true)',
cpp_constructor_args="torch::nn::SampleModuleOptions(false, true)",
input_size=(3, 4),
cpp_input_args=['torch::randn({3, 4})'],
cpp_input_args=["torch::randn({3, 4})"],
has_parity=False,
),
# This is to test that setting the `test_cpp_api_parity=False` flag skips
# the C++ API parity test accordingly (otherwise this test would run and
# throw a parity error).
dict(
fullname='SampleModule_THIS_TEST_SHOULD_BE_SKIPPED',
fullname="SampleModule_THIS_TEST_SHOULD_BE_SKIPPED",
constructor=lambda: SampleModule(False, True),
cpp_constructor_args='torch::nn::SampleModuleOptions(false, true)',
cpp_constructor_args="torch::nn::SampleModuleOptions(false, true)",
input_size=(3, 4),
cpp_input_args=['torch::randn({3, 4})'],
cpp_input_args=["torch::randn({3, 4})"],
test_cpp_api_parity=False,
),
]

View File

@ -1,37 +1,33 @@
from collections import namedtuple
import unittest
import os
import warnings
import shutil
import unittest
import warnings
from collections import namedtuple
import torch
import torch.utils.cpp_extension
import torch.testing._internal.common_nn as common_nn
import torch.utils.cpp_extension
from torch.testing._internal.common_cuda import TEST_CUDA
# Note that this namedtuple is for C++ parity test mechanism's internal use.
# For guidance on how to add a new C++ parity test, please see
# NOTE [How to check NN module / functional API parity between Python and C++ frontends]
TorchNNModuleTestParams = namedtuple(
'TorchNNModuleTestParams',
"TorchNNModuleTestParams",
[
# NN module name (e.g. "BCELoss")
'module_name',
"module_name",
# Unique identifier for this module config (e.g. "BCELoss_weights_cuda")
'module_variant_name',
"module_variant_name",
# An instance of an NN test class (e.g. `CriterionTest`) which stores
# necessary information (e.g. input / target / extra_args) for running the Python test
'test_instance',
"test_instance",
# Constructor arguments passed to the C++ module constructor, which must be
# strictly equivalent to the Python module constructor arguments
# (e.g. `torch::nn::BCELossOptions().weight(torch::rand(10))`,
# which is strictly equivalent to passing `torch.rand(10)` to `torch.nn.BCELoss`
# constructor in Python)
'cpp_constructor_args',
"cpp_constructor_args",
# All arguments used in NN module's forward pass.
# Please see `compute_arg_dict` function for details on how we construct this dict.
# (e.g.
@ -44,42 +40,35 @@ TorchNNModuleTestParams = namedtuple(
# }
# ```
# )
'arg_dict',
"arg_dict",
# Whether we expect this NN module test to pass the Python/C++ parity test
# (e.g. `True`)
'has_parity',
"has_parity",
# Device (e.g. "cuda")
'device',
"device",
# Temporary folder to store C++ outputs (to be compared with Python outputs later)
'cpp_tmp_folder',
]
"cpp_tmp_folder",
],
)
# Note that this namedtuple is for C++ parity test mechanism's internal use.
# For guidance on how to add a new C++ parity test, please see
# NOTE [How to check NN module / functional API parity between Python and C++ frontends]
TorchNNFunctionalTestParams = namedtuple(
'TorchNNFunctionalTestParams',
"TorchNNFunctionalTestParams",
[
# NN functional name (e.g. "binary_cross_entropy")
'functional_name',
"functional_name",
# Unique identifier for this functional config (e.g. "BCELoss_no_reduce_cuda")
'functional_variant_name',
"functional_variant_name",
# An instance of an NN test class (e.g. `NewModuleTest`) which stores
# necessary information (e.g. input / target / extra_args) for running the Python test
'test_instance',
"test_instance",
# The C++ function call that is strictly equivalent to the Python function call
# (e.g. "F::binary_cross_entropy(
# i, t.to(i.options()),F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))",
# which is strictly equivalent to `F.binary_cross_entropy(i, t.type_as(i), reduction='none')` in Python)
'cpp_function_call',
"cpp_function_call",
# All arguments used in NN functional's function call.
# Please see `compute_arg_dict` function for details on how we construct this dict.
# (e.g.
@ -92,21 +81,18 @@ TorchNNFunctionalTestParams = namedtuple(
# }
# ```
# )
'arg_dict',
"arg_dict",
# Whether we expect this NN functional test to pass the Python/C++ parity test
# (e.g. `True`)
'has_parity',
"has_parity",
# Device (e.g. "cuda")
'device',
"device",
# Temporary folder to store C++ outputs (to be compared with Python outputs later)
'cpp_tmp_folder',
]
"cpp_tmp_folder",
],
)
CppArg = namedtuple('CppArg', ['name', 'value'])
CppArg = namedtuple("CppArg", ["name", "value"])
TORCH_NN_COMMON_TEST_HARNESS = """
#include <torch/script.h>
@ -139,21 +125,27 @@ torch::Tensor _rand_tensor_non_equal(torch::IntArrayRef size) {
}
"""
def compile_cpp_code_inline(name, cpp_sources, functions):
cpp_module = torch.utils.cpp_extension.load_inline(
name=name,
cpp_sources=cpp_sources,
extra_cflags=['-g'], # Enable debug symbols by default for debugging test failures.
extra_cflags=[
"-g"
], # Enable debug symbols by default for debugging test failures.
functions=functions,
verbose=False,
)
return cpp_module
def compute_temp_file_path(cpp_tmp_folder, variant_name, file_suffix):
return os.path.join(cpp_tmp_folder, f'{variant_name}_{file_suffix}.pt')
return os.path.join(cpp_tmp_folder, f"{variant_name}_{file_suffix}.pt")
def is_torch_nn_functional_test(test_params_dict):
return 'wrap_functional' in str(test_params_dict.get('constructor', ''))
return "wrap_functional" in str(test_params_dict.get("constructor", ""))
def convert_to_list(python_input):
if isinstance(python_input, torch.Tensor):
@ -161,31 +153,46 @@ def convert_to_list(python_input):
else:
return list(python_input)
def set_python_tensors_requires_grad(python_tensors):
return [tensor.requires_grad_(True) if tensor.dtype != torch.long else tensor for tensor in python_tensors]
return [
tensor.requires_grad_(True) if tensor.dtype != torch.long else tensor
for tensor in python_tensors
]
def move_python_tensors_to_device(python_tensors, device):
return [tensor.to(device) for tensor in python_tensors]
def has_test(unit_test_class, test_name):
return hasattr(unit_test_class, test_name)
def add_test(unit_test_class, test_name, test_fn):
if has_test(unit_test_class, test_name):
raise RuntimeError("Found two tests with the same name: " + test_name)
setattr(unit_test_class, test_name, test_fn)
def set_cpp_tensors_requires_grad(cpp_tensor_stmts, python_tensors):
assert len(cpp_tensor_stmts) == len(python_tensors)
return [f'{tensor_stmt}.requires_grad_(true)' if tensor.dtype != torch.long else tensor_stmt
for tensor_stmt, (_, tensor) in zip(cpp_tensor_stmts, python_tensors)]
return [
f"{tensor_stmt}.requires_grad_(true)"
if tensor.dtype != torch.long
else tensor_stmt
for tensor_stmt, (_, tensor) in zip(cpp_tensor_stmts, python_tensors)
]
def move_cpp_tensors_to_device(cpp_tensor_stmts, device):
return [f'{tensor_stmt}.to("{device}")' for tensor_stmt in cpp_tensor_stmts]
def is_criterion_test(test_instance):
return isinstance(test_instance, common_nn.CriterionTest)
# This function computes the following:
# - What variable declaration statements should show up in the C++ parity test function
# - What arguments should be passed into the C++ module/functional's forward function
@ -213,26 +220,42 @@ def compute_cpp_args_construction_stmts_and_forward_arg_symbols(test_params):
cpp_forward_args_symbols.append(arg_name)
return args_stmts
cpp_forward_input_args_stmts = set_cpp_tensors_requires_grad(move_cpp_tensors_to_device(
add_cpp_forward_args(test_params.arg_dict['input']), device), test_params.arg_dict['input'])
cpp_forward_input_args_stmts = set_cpp_tensors_requires_grad(
move_cpp_tensors_to_device(
add_cpp_forward_args(test_params.arg_dict["input"]), device
),
test_params.arg_dict["input"],
)
cpp_forward_target_args_stmts = move_cpp_tensors_to_device(
add_cpp_forward_args(test_params.arg_dict['target']), device)
add_cpp_forward_args(test_params.arg_dict["target"]), device
)
cpp_forward_extra_args_stmts = move_cpp_tensors_to_device(
add_cpp_forward_args(test_params.arg_dict['extra_args']), device)
add_cpp_forward_args(test_params.arg_dict["extra_args"]), device
)
# Build the list of other arguments needed
cpp_other_args_stmts = []
for arg_name, _ in test_params.arg_dict['other']:
for arg_name, _ in test_params.arg_dict["other"]:
cpp_other_args_stmts.append(f'auto {arg_name} = arg_dict.at("{arg_name}")')
cpp_other_args_stmts = move_cpp_tensors_to_device(cpp_other_args_stmts, device)
cpp_args_construction_stmts = cpp_forward_input_args_stmts + cpp_forward_target_args_stmts + \
cpp_forward_extra_args_stmts + cpp_other_args_stmts
cpp_args_construction_stmts = (
cpp_forward_input_args_stmts
+ cpp_forward_target_args_stmts
+ cpp_forward_extra_args_stmts
+ cpp_other_args_stmts
)
return cpp_args_construction_stmts, cpp_forward_args_symbols
def serialize_arg_dict_as_script_module(arg_dict):
arg_dict_flat = dict(arg_dict['input'] + arg_dict['target'] + arg_dict['extra_args'] + arg_dict['other'])
arg_dict_flat = dict(
arg_dict["input"]
+ arg_dict["target"]
+ arg_dict["extra_args"]
+ arg_dict["other"]
)
arg_dict_module = torch.nn.Module()
for arg_name, arg_value in arg_dict_flat.items():
assert isinstance(arg_value, torch.Tensor)
@ -240,6 +263,7 @@ def serialize_arg_dict_as_script_module(arg_dict):
return torch.jit.script(arg_dict_module)
# NOTE: any argument symbol used in `cpp_constructor_args` / `cpp_options_args` / `cpp_function_call`
# must have a mapping in `cpp_var_map`.
#
@ -270,38 +294,47 @@ def serialize_arg_dict_as_script_module(arg_dict):
# ```
def compute_arg_dict(test_params_dict, test_instance):
arg_dict = {
'input': [],
'target': [],
'extra_args': [],
'other': [],
"input": [],
"target": [],
"extra_args": [],
"other": [],
}
def put_args_into_arg_dict(arg_type, arg_type_prefix, args):
for i, arg in enumerate(args):
arg_dict[arg_type].append(CppArg(name=arg_type_prefix + str(i), value=arg))
put_args_into_arg_dict('input', 'i', convert_to_list(test_instance._get_input()))
put_args_into_arg_dict("input", "i", convert_to_list(test_instance._get_input()))
if is_criterion_test(test_instance):
put_args_into_arg_dict('target', 't', convert_to_list(test_instance._get_target()))
put_args_into_arg_dict(
"target", "t", convert_to_list(test_instance._get_target())
)
if test_instance.extra_args:
put_args_into_arg_dict('extra_args', 'e', convert_to_list(test_instance.extra_args))
put_args_into_arg_dict(
"extra_args", "e", convert_to_list(test_instance.extra_args)
)
cpp_var_map = test_params_dict.get('cpp_var_map', {})
cpp_var_map = test_params_dict.get("cpp_var_map", {})
for arg_name, arg_value in cpp_var_map.items():
if isinstance(arg_value, str):
if arg_value == '_get_input()':
arg_dict['other'].append(CppArg(name=arg_name, value=test_instance._get_input()))
if arg_value == "_get_input()":
arg_dict["other"].append(
CppArg(name=arg_name, value=test_instance._get_input())
)
else:
raise RuntimeError(f"`{arg_name}` has unsupported string value: {arg_value}")
raise RuntimeError(
f"`{arg_name}` has unsupported string value: {arg_value}"
)
elif isinstance(arg_value, torch.Tensor):
arg_dict['other'].append(CppArg(name=arg_name, value=arg_value))
arg_dict["other"].append(CppArg(name=arg_name, value=arg_value))
else:
raise RuntimeError(f"`{arg_name}` has unsupported value: {arg_value}")
return arg_dict
def decorate_test_fn(test_fn, test_cuda, has_impl_parity, device):
if device == 'cuda':
if device == "cuda":
test_fn = unittest.skipIf(not TEST_CUDA, "CUDA unavailable")(test_fn)
test_fn = unittest.skipIf(not test_cuda, "Excluded from CUDA tests")(test_fn)
@ -313,7 +346,8 @@ def decorate_test_fn(test_fn, test_cuda, has_impl_parity, device):
return test_fn
MESSAGE_HOW_TO_FIX_CPP_PARITY_TEST_FAILURE = '''
MESSAGE_HOW_TO_FIX_CPP_PARITY_TEST_FAILURE = """
What should I do when C++ API parity test is failing?
- If you are changing the implementation of an existing `torch.nn` module / `torch.nn.functional` function:
@ -337,13 +371,15 @@ also file an issue when you do this.
For more details on how to add a C++ API parity test, please see:
NOTE [How to check NN module / functional API parity between Python and C++ frontends]
'''
"""
def generate_error_msg(name, cpp_value, python_value):
return (
"Parity test failed: {} in C++ has value: {}, "
"which does not match the corresponding value in Python: {}.\n{}").format(
name, cpp_value, python_value, MESSAGE_HOW_TO_FIX_CPP_PARITY_TEST_FAILURE)
"which does not match the corresponding value in Python: {}.\n{}"
).format(name, cpp_value, python_value, MESSAGE_HOW_TO_FIX_CPP_PARITY_TEST_FAILURE)
def try_remove_folder(folder_path):
if os.path.exists(folder_path):
@ -351,4 +387,6 @@ def try_remove_folder(folder_path):
try:
shutil.rmtree(folder_path)
except Exception as e:
warnings.warn(f"Non-blocking folder removal fails with the following error:\n{str(e)}")
warnings.warn(
f"Non-blocking folder removal fails with the following error:\n{str(e)}"
)