mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-22 14:15:01 +08:00
Summary: This PR makes the following improvements to C++ API parity test harness: 1. Remove `options_args` since we can get the list of options from the Python module constructor args. 2. Add test for mapping `int` or `tuple` in Python module constructor args to `ExpandingArray` in C++ module options. 3. Use regex to split up e.g. `(1, {2, 3}, 4)` into `['1', '{2, 3}', '4']` for `cpp_default_constructor_args`. 4. Add options arg accessor tests in `_test_torch_nn_module_ctor_args`. We will be able to merge https://github.com/pytorch/pytorch/pull/24160 and https://github.com/pytorch/pytorch/pull/24860 after these improvements. Pull Request resolved: https://github.com/pytorch/pytorch/pull/25828 Differential Revision: D17266197 Pulled By: yf225 fbshipit-source-id: 96d0d4a2fcc4b47cd1782d4df2c9bac107dec3f9
123 lines
4.1 KiB
Python
123 lines
4.1 KiB
Python
import torch
|
|
|
|
from cpp_api_parity import torch_nn_modules, TorchNNModuleMetadata
|
|
|
|
'''
`SampleModule` is used by `test_cpp_api_parity.py` to test that the Python / C++ API
parity test harness works for `torch.nn.Module` subclasses.

When `SampleModule.has_parity` is true, the behavior of `reset_parameters` / `forward` /
`backward` is the same as that of the C++ equivalent.

When `SampleModule.has_parity` is false, the behavior of `reset_parameters` / `forward` /
`backward` is different from that of the C++ equivalent.
'''
|
|
|
|
class SampleModule(torch.nn.Module):
    """Toy module used to exercise the Python / C++ API parity test harness.

    With ``has_parity=True`` the values produced by ``reset_parameters`` and
    ``forward`` match the C++ ``SampleModuleImpl`` in
    ``SAMPLE_MODULE_CPP_SOURCE`` (parameter and buffer filled with ones,
    ``attr == 10``, ``forward`` computes ``x + param * 2 + submodule(x)``).
    With ``has_parity=False`` the Python side deliberately uses different
    values.

    Args:
        has_parity: whether this instance should behave like the C++ module.
        has_submodule: if true, a child ``SampleModule`` (same ``has_parity``,
            no further submodule) is attached as ``self.submodule``.
        int_option, double_option, bool_option, string_option, tensor_option,
        int_or_tuple_option: option values that are only stored as attributes
            of the same name.
    """

    # NOTE: the default for `tensor_option` is evaluated once, when this `def`
    # is executed, so every instance built without an explicit value stores the
    # same tensor object. The default is left as-is because it is part of the
    # constructor signature.
    def __init__(self, has_parity, has_submodule, int_option=0, double_option=0.1,
                 bool_option=False, string_option='0', tensor_option=torch.zeros(1),
                 int_or_tuple_option=0):
        super(SampleModule, self).__init__()
        self.has_parity = has_parity
        if has_submodule:
            # The child never gets a submodule of its own, so nesting stops here.
            self.submodule = SampleModule(self.has_parity, False)

        # The following attributes will be included in the `num_attrs_recursive` count.
        self.has_submodule = has_submodule
        self.int_option = int_option
        self.double_option = double_option
        self.bool_option = bool_option
        self.string_option = string_option
        self.tensor_option = tensor_option
        self.int_or_tuple_option = int_or_tuple_option
        self.register_parameter('param', torch.nn.Parameter(torch.empty(3, 4)))
        self.register_buffer('buffer', torch.empty(4, 5))
        self.attr = 0

        # `param` / `buffer` are uninitialized (`torch.empty`) until this call.
        self.reset_parameters()

    def reset_parameters(self):
        """Set `param`, `buffer` and `attr` to their initial values.

        With parity: `param` and `buffer` are all ones and `attr` is 10.
        Without parity: `param` and `buffer` are all 11 and `attr` is 100.
        """
        # In-place updates of the parameter are done with autograd disabled.
        with torch.no_grad():
            self.param.fill_(1)
            self.buffer.fill_(1)
            self.attr = 10
            if not self.has_parity:
                self.param.add_(10)
                self.buffer.add_(10)
                self.attr += 90

    def forward(self, x):
        """Return `x + param * k + submodule(x)` (plus 3 when not in parity).

        `k` is 2 with parity and 4 without; the submodule term is 0 when this
        instance has no `submodule` attribute.
        """
        # Only instances built with `has_submodule=True` own a `submodule`.
        submodule_forward_result = self.submodule(x) if hasattr(self, 'submodule') else 0
        if self.has_parity:
            return x + self.param * 2 + submodule_forward_result
        return x + self.param * 4 + submodule_forward_result + 3
|
|
|
|
# C++ counterpart of the Python `SampleModule` in its parity configuration:
# `reset()` fills `param` / `buffer` with ones and sets `attr` to 10, and
# `forward` computes `x + param * 2 + submodule(x)`. This string is passed to
# the harness as `cpp_sources` in the `TorchNNModuleMetadata` entry below.
SAMPLE_MODULE_CPP_SOURCE = """\n
namespace torch {
namespace nn{
struct C10_EXPORT SampleModuleOptions {
  SampleModuleOptions(bool has_submodule) : has_submodule_(has_submodule) {}
  TORCH_ARG(bool, has_submodule);
  TORCH_ARG(int64_t, int_option) = 0;
  TORCH_ARG(double, double_option) = 0.1;
  TORCH_ARG(bool, bool_option) = false;
  TORCH_ARG(std::string, string_option) = "0";
  TORCH_ARG(torch::Tensor, tensor_option) = torch::zeros({1});
  TORCH_ARG(ExpandingArray<2>, int_or_tuple_option) = 0;
};

struct C10_EXPORT SampleModuleImpl : public torch::nn::Cloneable<SampleModuleImpl> {
  SampleModuleImpl(bool has_submodule) : SampleModuleImpl(SampleModuleOptions(has_submodule)) {}
  explicit SampleModuleImpl(SampleModuleOptions options) : options(std::move(options)) {
    // Read the member, not the constructor parameter: the parameter has been moved from.
    if (this->options.has_submodule_) {
      submodule = register_module("submodule", std::make_shared<SampleModuleImpl>(false));
    }
    reset();
  }
  void reset() {
    attr = 10;
    param = register_parameter("param", torch::ones({3, 4}));
    buffer = register_buffer("buffer", torch::ones({4, 5}));
  }
  torch::Tensor forward(torch::Tensor x) {
    return x + param * 2 + (submodule ? submodule->forward(x) : torch::zeros_like(x));
  }
  SampleModuleOptions options;
  torch::Tensor param;
  torch::Tensor buffer;
  int attr;
  std::shared_ptr<SampleModuleImpl> submodule{nullptr};
};

TORCH_MODULE(SampleModule);
}
}
"""
|
|
|
|
# Test-case descriptions for `SampleModule` (presumably collected by
# `test_cpp_api_parity.py` -- confirm against the harness).
module_tests = [
    # Parity case: the module is described by name plus constructor arguments
    # `(has_parity=True, has_submodule=True)`.
    dict(
        module_name='SampleModule',
        desc='has_parity',
        constructor_args=(True, True),
        cpp_constructor_args='(true)',
        input_size=(3, 4),
        has_parity=True,
    ),
    # No-parity case: the Python module is supplied through a `constructor`
    # callable that builds it with `has_parity=False`.
    dict(
        fullname='SampleModule_no_parity',
        constructor=lambda: SampleModule(False, True),
        cpp_constructor_args='(true)',
        input_size=(3, 4),
        has_parity=False,
    ),
]
|
|
|
|
# Register the C++ side of `SampleModule` under its module name: default C++
# constructor arguments, the expected recursive attribute count, and the C++
# source defined above.
torch_nn_modules.module_metadata_map['SampleModule'] = TorchNNModuleMetadata(
    cpp_default_constructor_args='(true)',
    num_attrs_recursive=20,
    cpp_sources=SAMPLE_MODULE_CPP_SOURCE,
)

# Make the class reachable as `torch.nn.SampleModule`
# (presumably so the harness can resolve it by `module_name` -- confirm).
torch.nn.SampleModule = SampleModule
|