Files
pytorch/test/cpp_api_parity/module_impl_check.py
Mikayla Gawarecki d9576c9440 Fix failures when default is flipped for weights_only (#127627)
Tests on XLA shard not fixed yet but there is an issue here https://github.com/pytorch/xla/issues/7799

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127627
Approved by: https://github.com/albanD
ghstack dependencies: #132349
2024-08-16 00:22:43 +00:00

382 lines
14 KiB
Python

# The purpose of this test is to check that we have implementation parity between
# a Python `torch.nn` module and its corresponding C++ `torch::nn` module. Concretely,
# this test does the following:
#
# 1. Get a test params dict from common_nn.py, run forward and backward on the
# Python module created using the test params.
#
# 2. Serialize the Python module's parameters / buffers and its forward input
# arguments, deserialize them in C++ and load them into the C++ module.
#
# 3. Run the same forward and backward passes on the C++ module, and serialize
# the C++ module's forward output and backward gradients.
#
# 4. Compare Python/C++ module's forward output and backward gradients. If they
# are the same, then we have implementation parity between Python/C++ module.
import os
import pprint
import tempfile
import types
from string import Template
import torch
from cpp_api_parity.sample_module import SAMPLE_MODULE_CPP_SOURCE
from cpp_api_parity.utils import (
add_test,
compile_cpp_code_inline,
compute_arg_dict,
compute_cpp_args_construction_stmts_and_forward_arg_symbols,
compute_temp_file_path,
decorate_test_fn,
generate_error_msg,
is_torch_nn_functional_test,
move_python_tensors_to_device,
serialize_arg_dict_as_script_module,
set_python_tensors_requires_grad,
TORCH_NN_COMMON_TEST_HARNESS,
TorchNNModuleTestParams,
try_remove_folder,
)
from torch.jit._pickle import restore_type_tag
# C++ source template for one generated parity-test function. Each
# instantiation defines `${module_variant_name}_test_forward_backward`, which
# loads the argument dict and the module's parameters/buffers from the given
# files, runs forward and backward on the C++ module, and writes the forward
# output and a dict of parameter gradients to the given output files.
#
# Expected substitutions (filled in by `generate_test_cpp_sources`):
#
# ${module_variant_name} (e.g. `Linear_no_bias_cuda`): prefix of the generated
#   C++ function name.
# ${module_qualified_name} (e.g. `torch::nn::Linear`): C++ module type.
# ${cpp_args_construction_stmts}: C++ statements that declare the arguments,
#   joined with `;`.
# ${cpp_constructor_args}: parenthesized C++ constructor arguments, or an
#   empty string when there are none.
# ${device}: device string the C++ module is moved to.
# ${cpp_forward_args_symbols}: comma-separated argument names passed to the
#   module's forward call.
TORCH_NN_MODULE_TEST_FORWARD_BACKWARD = Template(
"""
void ${module_variant_name}_test_forward_backward(
const std::string& arg_dict_file_path,
const std::string& module_file_path,
const std::string& forward_output_file_path,
const std::string& backward_grad_dict_file_path) {
pybind11::gil_scoped_release no_gil;
// Declare arguments
auto arg_dict = load_dict_from_file(arg_dict_file_path);
${cpp_args_construction_stmts};
// Construct module and load params/buffers from Python module
${module_qualified_name} module${cpp_constructor_args};
module->to(std::string("${device}"));
torch::load(module, module_file_path);
// Some modules (such as `RReLU`) create random tensors in their forward pass.
// To make sure the random tensors created are the same in Python/C++, we need
// to set the RNG seed manually.
torch::manual_seed(0);
// Forward pass
auto cpp_output = module(${cpp_forward_args_symbols});
// Save the output into a file to be compared in Python later
write_ivalue_to_file(torch::IValue(cpp_output), forward_output_file_path);
// Backward pass
if (cpp_output.is_complex()) {
cpp_output.sum().abs().backward();
} else {
cpp_output.sum().backward();
}
// Put all gradients into a c10::Dict, save it into a file to be compared in Python later
c10::Dict<std::string, torch::Tensor> grad_dict;
for (const auto& param : module->named_parameters()) {
torch::Tensor grad = param.value().grad();
if (grad.is_sparse()) {
grad_dict.insert(param.key() + "_grad_indices", grad.coalesce().indices());
grad_dict.insert(param.key() + "_grad_values", grad.coalesce().values());
} else {
grad_dict.insert(param.key() + "_grad", grad);
}
}
write_ivalue_to_file(torch::IValue(grad_dict), backward_grad_dict_file_path);
}
"""
)
def run_python_forward_backward(unit_test_class, test_params):
    """Run forward and backward on the Python module described by `test_params`.

    Args:
        unit_test_class: Unused here; kept so the signature matches its caller.
        test_params: Supplies `device`, `test_instance` (whose `constructor`
            and `constructor_args` build the module) and `arg_dict` (whose
            "input", "target" and "extra_args" entries are (name, value) pairs).

    Returns:
        A tuple `(script_module, python_output, python_grad_dict)`: the traced
        module (used only to carry parameters and buffers), the forward output,
        and a dict of parameter gradients keyed `<name>_grad` for dense
        gradients or `<name>_grad_indices` / `<name>_grad_values` for sparse.
    """
    device = test_params.device
    instance = test_params.test_instance
    module = instance.constructor(*instance.constructor_args).to(device)

    def values_on_device(section):
        # Drop the argument names and move the values to the target device.
        values = [value for _, value in test_params.arg_dict[section]]
        return move_python_tensors_to_device(values, device)

    # Only the "input" arguments are marked as requiring grad.
    inputs = set_python_tensors_requires_grad(values_on_device("input"))
    for section in ("target", "extra_args"):
        inputs += values_on_device(section)

    # Modules such as `RReLU` draw random tensors during forward; seeding here
    # (and in the C++ test) keeps those draws identical on both sides.
    torch.manual_seed(0)

    # Forward pass
    python_output = module(*inputs)

    # NOTE: Workaround that lets any module be traced: forward is swapped for
    # an identity function, which is fine because the traced module is only
    # used to transfer parameters and buffers to the C++ module.
    module.forward = types.MethodType(lambda self, input: input, module)
    script_module = torch.jit.trace(module, torch.tensor(0))

    # Backward pass; a complex sum is reduced with abs() before backward.
    reduced = python_output.sum()
    if python_output.dtype.is_complex:
        reduced = reduced.abs()
    reduced.backward()

    # Collect every parameter gradient for the later comparison.
    python_grad_dict = {}
    for name, param in module.named_parameters():
        grad = param.grad
        if not grad.is_sparse:
            python_grad_dict[f"{name}_grad"] = grad
            continue
        python_grad_dict[f"{name}_grad_indices"] = grad.coalesce().indices()
        python_grad_dict[f"{name}_grad_values"] = grad.coalesce().values()

    return script_module, python_output, python_grad_dict
def test_forward_backward(unit_test_class, test_params):
    """Check forward/backward parity between a Python module and its C++ twin.

    Runs the Python module, passes its parameters/buffers and arguments to the
    compiled C++ test function through files in `test_params.cpp_tmp_folder`,
    then compares the forward outputs and the per-parameter gradients.

    Args:
        unit_test_class: Test-case object that provides `assertEqual` and
            `assertTrue`, and whose `module_impl_check_cpp_module` attribute
            holds the compiled C++ test functions.
        test_params: Parameters of the module variant under test; its
            `module_variant_name`, `cpp_tmp_folder` and `arg_dict` are read here.

    Raises:
        Whatever `unit_test_class`'s assertion methods raise when the forward
        output, the number of gradients, or any gradient differs.
    """
    module_variant_name = test_params.module_variant_name
    cpp_tmp_folder = test_params.cpp_tmp_folder

    # Remove the temporary folder if it exists already
    try_remove_folder(cpp_tmp_folder)
    os.mkdir(cpp_tmp_folder)

    try:
        # Run forward and backward on Python module
        script_module, python_output, python_grad_dict = run_python_forward_backward(
            unit_test_class, test_params
        )

        # Save Python module and arguments to be used from C++ function
        module_file_path = compute_temp_file_path(
            cpp_tmp_folder, module_variant_name, "module"
        )
        arg_dict_file_path = compute_temp_file_path(
            cpp_tmp_folder, module_variant_name, "arg_dict"
        )
        script_module.save(module_file_path)
        serialize_arg_dict_as_script_module(test_params.arg_dict).save(
            arg_dict_file_path
        )

        cpp_test_name = f"{test_params.module_variant_name}_test_forward_backward"
        cpp_test_fn = getattr(
            unit_test_class.module_impl_check_cpp_module, cpp_test_name
        )

        forward_output_file_path = compute_temp_file_path(
            cpp_tmp_folder, module_variant_name, "forward_output"
        )
        backward_grad_dict_file_path = compute_temp_file_path(
            cpp_tmp_folder, module_variant_name, "backward_grad_dict"
        )

        cpp_test_fn(
            arg_dict_file_path,
            module_file_path,
            forward_output_file_path,
            backward_grad_dict_file_path,
        )
        cpp_output = torch.load(forward_output_file_path)
        # weights_only: need GLOBAL torch.jit._pickle.restore_type_tag
        with torch.serialization.safe_globals([restore_type_tag]):
            cpp_grad_dict = torch.load(backward_grad_dict_file_path)

        # Check that forward outputs are equal
        unit_test_class.assertEqual(
            python_output,
            cpp_output,
            msg=generate_error_msg("forward output", cpp_output, python_output),
        )

        # Check that module parameter gradients are equal after backward pass
        unit_test_class.assertEqual(
            len(python_grad_dict),
            len(cpp_grad_dict),
            msg=generate_error_msg(
                "# of parameters", len(cpp_grad_dict), len(python_grad_dict)
            ),
        )
        for key in python_grad_dict:
            # Recover the parameter name by stripping the gradient suffix that
            # was appended when the gradient dict was built.
            param_name = None
            for suffix in ["_grad", "_grad_indices", "_grad_values"]:
                if key.endswith(suffix):
                    param_name = key[: -len(suffix)]
                    break
            assert param_name is not None
            sparsity_str = (
                "sparse"
                if key.endswith(("_grad_indices", "_grad_values"))
                else "dense"
            )

            unit_test_class.assertTrue(
                key in cpp_grad_dict,
                msg=generate_error_msg(
                    f'"Does module have a parameter named `{param_name}` with {sparsity_str} gradient?"',
                    False,
                    True,
                ),
            )
            unit_test_class.assertEqual(
                python_grad_dict[key],
                cpp_grad_dict[key],
                msg=generate_error_msg(
                    f"`{param_name}`'s {sparsity_str} gradient (`{key}`)",
                    cpp_grad_dict[key],
                    python_grad_dict[key],
                ),
            )
    finally:
        # Remove temporary folder that stores C++ outputs. Done in `finally` so
        # a failing comparison or an exception does not leave the folder behind.
        try_remove_folder(cpp_tmp_folder)
def compute_module_name(test_params_dict):
    """Return the module name for a test params dict.

    A non-empty "fullname" wins: the name is everything before its first
    underscore. Otherwise the "module_name" entry is returned, or None when
    that key is absent.
    """
    fullname = test_params_dict.get("fullname", None)
    if not fullname:
        return test_params_dict.get("module_name")
    name_prefix, _, _ = fullname.partition("_")
    return name_prefix
def process_test_params_for_module(test_params_dict, device, test_instance_class):
    """Build the `TorchNNModuleTestParams` for one module test on one device.

    Args:
        test_params_dict: Module test dict. It is modified in place: a
            "constructor" entry is filled in from `torch.nn` when missing.
        device: Device string; appended to the variant name unless it is "cpu".
        test_instance_class: Class instantiated with `**test_params_dict`; the
            instance's `get_name()` must start with "test_".

    Returns:
        A `TorchNNModuleTestParams` whose `cpp_tmp_folder` is a directory
        freshly created by `tempfile.mkdtemp()`.

    Raises:
        AssertionError: If the test name does not start with "test_", or if
            "constructor_args" is given without "cpp_constructor_args".
    """
    module_name = compute_module_name(test_params_dict)
    test_params_dict["constructor"] = test_params_dict.get(
        "constructor", getattr(torch.nn, module_name)
    )
    test_instance = test_instance_class(**test_params_dict)
    assert test_instance.get_name().startswith("test_")
    # Example output: `BCELoss_weights_cuda`
    module_variant_name = test_instance.get_name()[5:] + (
        ("_" + device) if device != "cpu" else ""
    )

    if "constructor_args" in test_params_dict:
        # The newline after the pformat dump keeps the guidance text from being
        # glued onto the last line of the printed dict.
        assert "cpp_constructor_args" in test_params_dict, (
            "If `constructor_args` is present in test params dict, to enable C++ API parity test, "
            f"`cpp_constructor_args` must be present in:\n{pprint.pformat(test_params_dict)}\n"
            "If you are interested in adding the C++ API parity test, please see:\n"
            "NOTE [How to check NN module / functional API parity between Python and C++ frontends]. \n"
            "If not, please add `test_cpp_api_parity=False` to the test params dict and file an issue about this."
        )

    return TorchNNModuleTestParams(
        module_name=module_name,
        module_variant_name=module_variant_name,
        test_instance=test_instance,
        cpp_constructor_args=test_params_dict.get("cpp_constructor_args", ""),
        arg_dict=compute_arg_dict(test_params_dict, test_instance),
        has_parity=test_params_dict.get("has_parity", True),
        device=device,
        cpp_tmp_folder=tempfile.mkdtemp(),
    )
def write_test_to_test_class(
    unit_test_class, test_params_dict, test_instance_class, parity_table, devices
):
    """Register one parity test per device on `unit_test_class`.

    For every device, the processed test params are stored in
    `unit_test_class.module_test_params_map` under the test name
    `test_torch_nn_<variant>`, and a decorated test method with that name is
    added to the class.

    Raises:
        AssertionError: If the dict describes a functional test, names a module
            that `torch.nn` does not have, or names a module with no entry in
            the "torch::nn" section of `parity_table`.
    """
    assert not is_torch_nn_functional_test(test_params_dict)

    module_name = compute_module_name(test_params_dict)

    assert hasattr(torch.nn, module_name), (
        f"`torch.nn` doesn't have module `{module_name}`. "
        "If you are adding a new test, please set `fullname` using format `ModuleName_desc` "
        f"or set `module_name` using format `ModuleName` in the module test dict:\n{pprint.pformat(test_params_dict)}"
    )

    module_full_name = "torch::nn::" + module_name

    assert module_full_name in parity_table["torch::nn"], (
        f"Please add `{module_full_name}` entry to `torch::nn` section of `test/cpp_api_parity/parity-tracker.md`. "
        f"(Discovered while processing\n{pprint.pformat(test_params_dict)}.)"
    )

    for device in devices:
        params = process_test_params_for_module(
            test_params_dict=test_params_dict,
            device=device,
            test_instance_class=test_instance_class,
        )
        # The params step creates the temp folder; it is removed again here.
        # NOTE(review): presumably because the test itself recreates it, confirm.
        try_remove_folder(params.cpp_tmp_folder)

        unit_test_name = f"test_torch_nn_{params.module_variant_name}"
        unit_test_class.module_test_params_map[unit_test_name] = params

        def test_fn(self):
            # Params are looked up by the running test's name, so this closure
            # does not capture the loop's per-device state.
            registered = unit_test_class.module_test_params_map[self._testMethodName]
            test_forward_backward(unit_test_class=self, test_params=registered)

        has_impl_parity = parity_table["torch::nn"][module_full_name][
            0
        ] and test_params_dict.get("has_parity", True)
        decorated = decorate_test_fn(
            test_fn=test_fn,
            test_cuda=test_params_dict.get("test_cuda", True),
            has_impl_parity=has_impl_parity,
            device=device,
        )

        add_test(unit_test_class, unit_test_name, decorated)
def generate_test_cpp_sources(test_params, template):
    """Render the C++ test source for one module variant.

    Args:
        test_params: Supplies `device`, `cpp_constructor_args`,
            `module_variant_name` and `module_name`.
        template: A `string.Template`-style object whose `substitute` is given
            the six substitutions listed in the mapping below.

    Returns:
        The substituted C++ source text.
    """
    constructor_args = test_params.cpp_constructor_args
    # Non-empty constructor arguments are wrapped in parentheses; an empty
    # string is passed through unchanged.
    if constructor_args != "":
        constructor_args = f"({constructor_args})"

    stmts, forward_symbols = (
        compute_cpp_args_construction_stmts_and_forward_arg_symbols(test_params)
    )

    substitutions = {
        "module_variant_name": test_params.module_variant_name,
        "module_qualified_name": f"torch::nn::{test_params.module_name}",
        "cpp_args_construction_stmts": ";\n ".join(stmts),
        "cpp_constructor_args": constructor_args,
        "cpp_forward_args_symbols": ", ".join(forward_symbols),
        "device": test_params.device,
    }
    return template.substitute(substitutions)
# Build all C++ tests together, instead of once per test.
def build_cpp_tests(unit_test_class, print_cpp_source=False):
    """Compile every registered C++ parity test into one extension module.

    The common harness and sample-module sources are followed by one generated
    test function per entry in `unit_test_class.module_test_params_map`. The
    compiled module is stored on `unit_test_class.module_impl_check_cpp_module`.

    Args:
        unit_test_class: Class holding the non-empty `module_test_params_map`.
        print_cpp_source: When true, print the combined C++ source before
            compiling.
    """
    params_map = unit_test_class.module_test_params_map
    assert len(params_map) > 0

    source_parts = [TORCH_NN_COMMON_TEST_HARNESS + SAMPLE_MODULE_CPP_SOURCE]
    functions = []
    for params in params_map.values():
        source_parts.append(
            generate_test_cpp_sources(
                test_params=params, template=TORCH_NN_MODULE_TEST_FORWARD_BACKWARD
            )
        )
        functions.append(f"{params.module_variant_name}_test_forward_backward")
    cpp_sources = "".join(source_parts)

    if print_cpp_source:
        print(cpp_sources)

    unit_test_class.module_impl_check_cpp_module = compile_cpp_code_inline(
        name="module_impl_check", cpp_sources=cpp_sources, functions=functions
    )