C++ parity, convert_parameters (#29267)

Summary:
yf225 https://github.com/pytorch/pytorch/issues/25883
update parameters_to_vector and vector_to_parameters
check please!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29267

Differential Revision: D18628571

Pulled By: yf225

fbshipit-source-id: 03783e6b0f8183dd97ae48f3da4acb1d07083555
This commit is contained in:
lsrock1
2019-11-20 19:57:27 -08:00
committed by Facebook Github Bot
parent bbb3c415c9
commit 0a77c090d5
4 changed files with 127 additions and 2 deletions

View File

@ -140,3 +140,50 @@ TEST_F(NNUtilsTest, ClipGradValue) {
utils::clip_grad_value_(params, clip_value);
ASSERT_TRUE(torch::allclose(p1.grad(), p2.grad()));
}
TEST_F(NNUtilsTest, ConvertParameters) {
  std::vector<torch::Tensor> parameters{
    torch::arange(9, torch::kFloat32),
    torch::arange(9, torch::kFloat32).view({3, 3}),
    torch::arange(8, torch::kFloat32).view({2, 2, 2})
  };

  // parameters_to_vector should flatten each tensor and concatenate them.
  auto expected = torch::cat({
    torch::arange(9, torch::kFloat32),
    torch::arange(9, torch::kFloat32).view(-1),
    torch::arange(8, torch::kFloat32).view(-1)
  });
  auto vector = utils::parameters_to_vector(parameters);
  ASSERT_TRUE(vector.allclose(expected));

  // vector_to_parameters should write the flat vector back into
  // same-shaped (but zeroed) parameters, recovering the originals.
  std::vector<torch::Tensor> zero_parameters{
    torch::zeros({9}, torch::kFloat32),
    torch::zeros({9}, torch::kFloat32).view({3, 3}),
    torch::zeros({8}, torch::kFloat32).view({2, 2, 2})
  };

  utils::vector_to_parameters(vector, zero_parameters);
  // size_t avoids a signed/unsigned comparison against vector::size().
  for (size_t i = 0; i < zero_parameters.size(); ++i) {
    ASSERT_TRUE(zero_parameters[i].allclose(parameters[i]));
  }

  {
    auto conv1 = Conv2d(3, 10, 5);  // to compute the total parameter count
    auto fc1 = Linear(10, 20);
    auto model = Sequential(conv1, fc1);

    auto vec = utils::parameters_to_vector(model->parameters());
    // conv: 10*3*5*5 + 10 = 760; fc: 20*10 + 20 = 220; total = 980.
    ASSERT_EQ(vec.size(0), 980);
  }
  {
    auto conv1 = Conv2d(3, 10, 5);
    auto fc1 = Linear(10, 20);
    auto model = Sequential(conv1, fc1);

    auto vec = torch::arange(0., 980);
    utils::vector_to_parameters(vec, model->parameters());

    // The first conv filter row must hold the first five vector entries.
    auto sample = model->parameters()[0][0][0][0];
    ASSERT_TRUE(torch::equal(sample.data(), vec.data().slice(0, 0, 5)));
  }
}

View File

@ -129,8 +129,8 @@ torch::nn::DataParallel|No|No
torch::nn::parallel::DistributedDataParallel|No|No
torch::nn::utils::clip_grad_norm_|Yes|No
torch::nn::utils::clip_grad_value_|Yes|No
torch::nn::utils::parameters_to_vector|No|No
torch::nn::utils::vector_to_parameters|No|No
torch::nn::utils::parameters_to_vector|Yes|No
torch::nn::utils::vector_to_parameters|Yes|No
torch::nn::utils::weight_norm|No|No
torch::nn::utils::remove_weight_norm|No|No
torch::nn::utils::spectral_norm|No|No

View File

@ -1,3 +1,4 @@
#pragma once
#include <torch/nn/utils/clip_grad.h>
#include <torch/nn/utils/convert_parameters.h>

View File

@ -0,0 +1,77 @@
#pragma once
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/types.h>
namespace torch {
namespace nn {
namespace utils {
// This helper function is to check if the parameters are located
// in the same device. Currently, the conversion between model parameters
// and single vector form is not supported for multiple allocations,
// e.g. parameters in different GPUs, or mixture of CPU/GPU.
// This helper function is to check if the parameters are located
// in the same device. Currently, the conversion between model parameters
// and single vector form is not supported for multiple allocations,
// e.g. parameters in different GPUs, or mixture of CPU/GPU.
//
// The device marker is the CUDA device index, or -1 for CPU. Returns the
// (possibly newly established) marker; throws via TORCH_CHECK if `param`
// lives on a different device than previously seen parameters.
inline c10::optional<int64_t> _check_param_device(
    const torch::Tensor& param,
    c10::optional<int64_t> old_param_device) {
  // Device marker for this parameter.
  const int64_t device = param.is_cuda() ? param.get_device() : -1;

  // Meet the first parameter: just record where it lives.
  if (!old_param_device.has_value()) {
    return device;
  }

  // Both CPU (-1 == -1) and matching-GPU cases reduce to marker equality.
  TORCH_CHECK(
      device == old_param_device.value(),
      "Found two parameters on different devices, ",
      "this is currently not supported.");
  return old_param_device;
}
// Convert parameters to one vector
// Convert parameters to one vector: flattens every parameter to 1-D and
// concatenates the pieces, in order, into a single tensor.
//
// All parameters must reside on the same device; _check_param_device
// throws otherwise.
inline torch::Tensor parameters_to_vector(
    const std::vector<torch::Tensor>& parameters) {
  c10::optional<int64_t> param_device;

  std::vector<torch::Tensor> flattened;
  flattened.reserve(parameters.size());

  for (const torch::Tensor& param : parameters) {
    // Device check runs before view(-1) so a device mismatch is reported
    // ahead of any reshape error.
    param_device = _check_param_device(param, param_device);
    flattened.push_back(param.view(-1));
  }
  return torch::cat(flattened);
}
// Convert one vector to the parameters
// Convert one vector to the parameters: successive slices of `vec` are
// reshaped to each parameter's shape and installed as that parameter's data.
//
// `vec` is the flat vector (as produced by parameters_to_vector);
// `parameters` are the tensors to overwrite, consumed in order. All
// parameters must reside on the same device (checked via
// _check_param_device).
//
// Note: `parameters` is taken by const reference — the previous by-value
// signature copied the whole vector of tensor handles (one atomic refcount
// bump per tensor) for no benefit; Tensor is a handle type, so set_data on
// the referenced handles mutates the callers' tensors either way.
inline void vector_to_parameters(
    const torch::Tensor& vec,
    const std::vector<torch::Tensor>& parameters) {
  // Flag for the device where the parameters are located.
  c10::optional<int64_t> param_device;
  // Offset of the current parameter's data within `vec`.
  int64_t pointer = 0;

  for (const torch::Tensor& param : parameters) {
    // Ensure the parameters are located in the same device.
    param_device = _check_param_device(param, param_device);

    // The length of this parameter.
    const int64_t num_param = param.numel();
    // Slice the vector, reshape it, and replace the old data of the parameter.
    param.set_data(
        vec.slice(0, pointer, pointer + num_param).view_as(param).data());
    pointer += num_param;
  }
}
} // namespace utils
} // namespace nn
} // namespace torch