C++ parity, convert_parameters (#29267)
Summary: yf225 https://github.com/pytorch/pytorch/issues/25883 updates parameters_to_vector and vector_to_parameters; please check!

Pull Request resolved: https://github.com/pytorch/pytorch/pull/29267
Differential Revision: D18628571
Pulled By: yf225
fbshipit-source-id: 03783e6b0f8183dd97ae48f3da4acb1d07083555
Committed by: Facebook Github Bot
Parent: bbb3c415c9
Commit: 0a77c090d5
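For orientation before the diff: a minimal usage sketch of the two helpers this commit adds, torch::nn::utils::parameters_to_vector and torch::nn::utils::vector_to_parameters. The Sequential model and layer sizes below are illustrative, not taken from the commit.

#include <torch/torch.h>

int main() {
  namespace utils = torch::nn::utils;

  // Any module works; the layer sizes here are arbitrary.
  auto model = torch::nn::Sequential(
      torch::nn::Linear(4, 3),
      torch::nn::Linear(3, 2));

  // All parameters flattened into a single 1-D tensor.
  torch::Tensor flat = utils::parameters_to_vector(model->parameters());

  // Modify the flat vector (e.g. a gradient-free optimizer step) ...
  flat = flat + 0.01 * torch::randn_like(flat);

  // ... and write it back into the module's parameters in place.
  utils::vector_to_parameters(flat, model->parameters());
}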
@ -140,3 +140,50 @@ TEST_F(NNUtilsTest, ClipGradValue) {
  utils::clip_grad_value_(params, clip_value);
  ASSERT_TRUE(torch::allclose(p1.grad(), p2.grad()));
}

TEST_F(NNUtilsTest, ConvertParameters) {
  std::vector<torch::Tensor> parameters{
    torch::arange(9, torch::kFloat32),
    torch::arange(9, torch::kFloat32).view({3, 3}),
    torch::arange(8, torch::kFloat32).view({2, 2, 2})
  };

  auto expected = torch::cat({
    torch::arange(9, torch::kFloat32),
    torch::arange(9, torch::kFloat32).view(-1),
    torch::arange(8, torch::kFloat32).view(-1)
  });
  auto vector = utils::parameters_to_vector(parameters);
  ASSERT_TRUE(vector.allclose(expected));

  std::vector<torch::Tensor> zero_parameters{
    torch::zeros({9}, torch::kFloat32),
    torch::zeros({9}, torch::kFloat32).view({3, 3}),
    torch::zeros({8}, torch::kFloat32).view({2, 2, 2})
  };

  utils::vector_to_parameters(vector, zero_parameters);
  for (int i = 0; i < zero_parameters.size(); ++i) {
    ASSERT_TRUE(zero_parameters[i].allclose(parameters[i]));
  }

  {
    auto conv1 = Conv2d(3, 10, 5);
    auto fc1 = Linear(10, 20);
    auto model = Sequential(conv1, fc1);

    auto vec = utils::parameters_to_vector(model->parameters());
    ASSERT_EQ(vec.size(0), 980);
  }
  {
    auto conv1 = Conv2d(3, 10, 5);
    auto fc1 = Linear(10, 20);
    auto model = Sequential(conv1, fc1);

    auto vec = torch::arange(0., 980);
    utils::vector_to_parameters(vec, model->parameters());

    auto sample = model->parameters()[0][0][0][0];
    ASSERT_TRUE(torch::equal(sample.data(), vec.data().slice(0, 0, 5)));
  }
}
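Why 980 in the ASSERT_EQ above: the flattened length is just the total parameter count of the two layers. The arithmetic below is mine, not spelled out in the commit.

// Parameter count behind ASSERT_EQ(vec.size(0), 980) in the test above.
constexpr int conv_params = 10 * 3 * 5 * 5 + 10;  // Conv2d(3, 10, 5): 750 weights + 10 biases = 760
constexpr int fc_params   = 20 * 10 + 20;         // Linear(10, 20):   200 weights + 20 biases = 220
static_assert(conv_params + fc_params == 980, "flattened vector length expected by the test");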
@ -129,8 +129,8 @@ torch::nn::DataParallel|No|No
torch::nn::parallel::DistributedDataParallel|No|No
torch::nn::utils::clip_grad_norm_|Yes|No
torch::nn::utils::clip_grad_value_|Yes|No
-torch::nn::utils::parameters_to_vector|No|No
-torch::nn::utils::vector_to_parameters|No|No
+torch::nn::utils::parameters_to_vector|Yes|No
+torch::nn::utils::vector_to_parameters|Yes|No
torch::nn::utils::weight_norm|No|No
torch::nn::utils::remove_weight_norm|No|No
torch::nn::utils::spectral_norm|No|No
@ -1,3 +1,4 @@
#pragma once

#include <torch/nn/utils/clip_grad.h>
+#include <torch/nn/utils/convert_parameters.h>
torch/csrc/api/include/torch/nn/utils/convert_parameters.h (new file, 77 lines)
@ -0,0 +1,77 @@
#pragma once

#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/types.h>

namespace torch {
namespace nn {
namespace utils {

// This helper function checks whether all parameters are located on the
// same device. Currently, converting between model parameters and a single
// vector form is not supported for multiple allocations,
// e.g. parameters on different GPUs, or a mixture of CPU and GPU.
inline c10::optional<int64_t> _check_param_device(const torch::Tensor& param, c10::optional<int64_t> old_param_device) {
  // First parameter encountered
  if (old_param_device == c10::nullopt) {
    old_param_device = param.is_cuda() ? param.get_device() : -1;
  }
  else {
    bool warn = false;
    if (param.is_cuda()) { // Check if on the same GPU
      warn = (param.get_device() != old_param_device.value());
    }
    else { // Check if on the CPU
      warn = (old_param_device.value() != -1);
    }
    if (warn) {
      TORCH_CHECK(false, "Found two parameters on different devices, ",
                  "this is currently not supported.");
    }
  }

  return old_param_device;
}

// Convert parameters to one vector
inline torch::Tensor parameters_to_vector(const std::vector<torch::Tensor>& parameters) {
  c10::optional<int64_t> param_device;

  std::vector<torch::Tensor> vec;
  vec.reserve(parameters.size());

  for (const torch::Tensor& param : parameters) {
    // Ensure the parameters are located on the same device
    param_device = _check_param_device(param, param_device);

    vec.push_back(param.view(-1));
  }

  return torch::cat(vec);
}

// Convert one vector to the parameters
inline void vector_to_parameters(const torch::Tensor& vec, std::vector<torch::Tensor> parameters) {
  // Flag for the device where the parameters are located
  c10::optional<int64_t> param_device;

  // Pointer for slicing the vector for each parameter
  int64_t pointer = 0;
  int64_t num_param;
  for (torch::Tensor& param : parameters) {
    // Ensure the parameters are located on the same device
    param_device = _check_param_device(param, param_device);

    // The length of the parameter
    num_param = param.numel();
    // Slice the vector, reshape it, and replace the old data of the parameter
    param.set_data(vec.slice(0, pointer, pointer + num_param).view_as(param).data());

    // Increment the pointer
    pointer += num_param;
  }
}

} // namespace utils
} // namespace nn
} // namespace torch
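Finally, a hedged sketch of the failure mode that _check_param_device guards against: mixing a CPU parameter with a GPU parameter. It only demonstrates anything on a CUDA build; the tensors and sizes are illustrative, not from the commit.

#include <torch/torch.h>
#include <vector>

int main() {
  if (!torch::cuda::is_available()) {
    return 0;  // nothing to demonstrate on a CPU-only build
  }

  // One parameter on the CPU and one on the GPU.
  std::vector<torch::Tensor> mixed{
      torch::zeros({4}),
      torch::zeros({4}, torch::kCUDA)};

  try {
    torch::nn::utils::parameters_to_vector(mixed);
  } catch (const c10::Error&) {
    // Expected: "Found two parameters on different devices, this is currently not supported."
  }
}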