Files
pytorch/torch/csrc/api/src/nn/init.cpp
cyy dee100945e [2/N] Move c10::variant to std::variant (#109723)
This PR moves most of c10::variant calls to std::variant.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109723
Approved by: https://github.com/ezyang
2023-09-24 02:47:43 +00:00

257 lines
6.5 KiB
C++

#include <torch/nn/init.h>
#include <torch/linalg.h>
#include <torch/types.h>
#include <torch/utils.h>
#include <ATen/ATen.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <tuple>
namespace torch {
namespace nn {
namespace init {
namespace {
struct Fan {
explicit Fan(Tensor& tensor) {
const auto dimensions = tensor.ndimension();
TORCH_CHECK(
dimensions >= 2,
"Fan in and fan out can not be computed for tensor with fewer than 2 dimensions");
if (dimensions == 2) {
in = tensor.size(1);
out = tensor.size(0);
} else {
in = tensor.size(1) * tensor[0][0].numel();
out = tensor.size(0) * tensor[0][0].numel();
}
}
int64_t in;
int64_t out;
};
// Returns gain / sqrt(fan), where `gain` comes from
// calculate_gain(nonlinearity, a) and `fan` is the tensor's fan-in when
// `mode` holds kFanIn, otherwise its fan-out.
double calculate_kaiming_std(
    Tensor tensor,
    double a,
    FanModeType mode,
    NonlinearityType nonlinearity) {
  NoGradGuard no_grad;
  Fan fan(tensor);
  const auto gain = calculate_gain(nonlinearity, a);
  const bool use_fan_in = std::holds_alternative<enumtype::kFanIn>(mode);
  const auto fan_count = use_fan_in ? fan.in : fan.out;
  return gain / std::sqrt(fan_count);
}
} // namespace
// Returns the gain associated with `nonlinearity`:
//   kTanh      -> 5 / 3
//   kReLU      -> sqrt(2)
//   kLeakyReLU -> sqrt(2 / (1 + param^2)), with `param` used as the slope
//   anything else -> 1
// `param` is only read for kLeakyReLU.
double calculate_gain(NonlinearityType nonlinearity, double param) {
  if (std::holds_alternative<enumtype::kTanh>(nonlinearity)) {
    return 5.0 / 3.0; // NOLINT
  }
  if (std::holds_alternative<enumtype::kReLU>(nonlinearity)) {
    return std::sqrt(2.0); // NOLINT
  }
  if (std::holds_alternative<enumtype::kLeakyReLU>(nonlinearity)) {
    // Square by multiplication instead of the unqualified C `pow`.
    return std::sqrt(2.0 / (1.0 + param * param)); // NOLINT
  }
  return 1.0;
}
// Fills `tensor` in place with `value` (gradient recording disabled) and
// returns it.
Tensor constant_(Tensor tensor, Scalar value) {
  NoGradGuard no_grad;
  tensor.fill_(value);
  return tensor;
}
// Zeroes `tensor` in place and then writes 1 at the central position of each
// `tensor[d][d]` slice, for d in [0, min(size(0), size(1))). The central
// position uses index size(k) / 2 along every trailing dimension k >= 2.
// Accepts 3-D (temporal), 4-D (spatial) and 5-D (volumetric) tensors only.
Tensor dirac_(Tensor tensor) {
  NoGradGuard no_grad;
  TORCH_CHECK(
      tensor.ndimension() >= 3 && tensor.ndimension() <= 5,
      "Only tensors with 3, 4, or 5 dimensions are supported");
  const auto rank = tensor.ndimension();
  const auto shape = tensor.sizes();
  const auto diagonal_length = std::min(shape[0], shape[1]);
  tensor.zero_();
  for (const auto channel : c10::irange(diagonal_length)) {
    // Start at the [channel][channel] slice and descend to its middle
    // element one trailing axis at a time.
    auto center = tensor[channel][channel];
    for (const auto axis : c10::irange(int64_t{2}, rank)) {
      center = center[shape[axis] / 2];
    }
    center.fill_(1);
  }
  return tensor;
}
// Writes the result of torch::eye_out for a size(0) x size(1) matrix into
// `matrix` and returns it. Only 2-D tensors are accepted.
Tensor eye_(Tensor matrix) {
  NoGradGuard no_grad;
  TORCH_CHECK(
      matrix.ndimension() == 2, "Only tensors with 2 dimensions are supported");
  const auto row_count = matrix.size(0);
  const auto column_count = matrix.size(1);
  return torch::eye_out(matrix, row_count, column_count);
}
// Fills `tensor` in place via Tensor::normal_(mean, std) with gradient
// recording disabled, then returns it.
Tensor normal_(Tensor tensor, double mean, double std) {
  NoGradGuard no_grad;
  tensor.normal_(mean, std);
  return tensor;
}
// Fills `tensor` in place with the value 1 (gradient recording disabled) and
// returns it.
Tensor ones_(Tensor tensor) {
  NoGradGuard no_grad;
  tensor.fill_(1);
  return tensor;
}
// Fills `tensor` in place with a (semi-)orthogonal matrix scaled by `gain`
// and returns it.
//
// The tensor is treated as a matrix of shape (size(0), numel / size(0)). A
// random normal matrix of that shape is QR-factorized, the columns of Q are
// multiplied by the signs of R's diagonal, and the result is copied into
// `tensor` before scaling by `gain`.
//
// Requires at least 2 dimensions. A tensor with zero elements is returned
// unchanged.
Tensor orthogonal_(Tensor tensor, double gain) {
  NoGradGuard guard;
  TORCH_CHECK(
      tensor.ndimension() >= 2,
      "Only tensors with 2 or more dimensions are supported");
  // Nothing to initialize. This also guards the division below against
  // size(0) == 0, which would otherwise be an integer division by zero.
  if (tensor.numel() == 0) {
    return tensor;
  }
  const auto rows = tensor.size(0);
  const auto columns = tensor.numel() / rows;
  auto flattened = torch::randn({rows, columns});
  // QR is taken on the tall orientation; transpose back afterwards.
  const bool transposed = rows < columns;
  if (transposed) {
    flattened.t_();
  }
  // Compute the qr factorization
  auto [q, r] = torch::linalg::qr(flattened);
  // Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
  const auto diagonal = torch::diag(r, 0);
  q *= diagonal.sign();
  if (transposed) {
    q.t_();
  }
  tensor.view_as(q).copy_(q);
  tensor.mul_(gain);
  return tensor;
}
// Fills the 2-D `tensor` in place with normal(0, std) samples and then, for
// every column, overwrites ceil(sparsity * rows) randomly chosen rows with
// zero. Returns the tensor.
//
// The rows to zero are the leading entries of a fresh random permutation of
// [0, rows) drawn independently for each column.
Tensor sparse_(Tensor tensor, double sparsity, double std) {
  NoGradGuard guard;
  TORCH_CHECK(
      tensor.ndimension() == 2, "Only tensors with 2 dimensions are supported");
  const auto rows = tensor.size(0);
  const auto columns = tensor.size(1);
  const auto num_zeros = static_cast<int64_t>(std::ceil(sparsity * rows));
  tensor.normal_(0, std);
  // Loop-invariant pieces: the index options and the block of zeros written
  // into each column are identical for every iteration.
  const auto index_options = tensor.options().dtype(kLong);
  const auto zero_values = torch::zeros(num_zeros, tensor.options());
  for (const auto column : c10::irange(columns)) {
    auto row_indices = torch::randperm(rows, index_options);
    auto zero_indices =
        row_indices.slice(/*dim=*/0, /*start=*/0, /*end=*/num_zeros);
    tensor.index_put_(
        {zero_indices, torch::tensor(column, index_options)}, zero_values);
  }
  return tensor;
}
// Fills `tensor` in place via Tensor::uniform_(low, high) with gradient
// recording disabled, then returns it.
Tensor uniform_(Tensor tensor, double low, double high) {
  NoGradGuard no_grad;
  tensor.uniform_(low, high);
  return tensor;
}
// Fills `tensor` in place with uniform samples from [-bound, bound], where
// bound = sqrt(3) * calculate_kaiming_std(tensor, a, mode, nonlinearity).
// Returns the tensor.
Tensor kaiming_uniform_(
    Tensor tensor,
    double a,
    FanModeType mode,
    NonlinearityType nonlinearity) {
  NoGradGuard no_grad;
  const auto stddev = calculate_kaiming_std(tensor, a, mode, nonlinearity);
  // Convert the target standard deviation to the half-width of the interval.
  const auto limit = std::sqrt(3.0) * stddev;
  tensor.uniform_(-limit, limit);
  return tensor;
}
// Fills `tensor` in place with normal samples of mean 0 and standard
// deviation calculate_kaiming_std(tensor, a, mode, nonlinearity).
// Returns the tensor.
Tensor kaiming_normal_(
    Tensor tensor,
    double a,
    FanModeType mode,
    NonlinearityType nonlinearity) {
  NoGradGuard no_grad;
  const auto stddev = calculate_kaiming_std(tensor, a, mode, nonlinearity);
  tensor.normal_(0, stddev);
  return tensor;
}
// Fills `tensor` in place with normal samples of mean 0 and standard
// deviation gain * sqrt(2 / (fan_in + fan_out)). Returns the tensor.
Tensor xavier_normal_(Tensor tensor, double gain) {
  NoGradGuard no_grad;
  Fan fan(tensor);
  const auto fan_total = static_cast<double>(fan.in + fan.out);
  const auto stddev = gain * std::sqrt(2.0 / fan_total);
  tensor.normal_(0, stddev);
  return tensor;
}
// Fills `tensor` in place with uniform samples from [-a, a], where
// a = sqrt(3) * gain * sqrt(2 / (fan_in + fan_out)). Returns the tensor.
Tensor xavier_uniform_(Tensor tensor, double gain) {
  NoGradGuard no_grad;
  Fan fan(tensor);
  const auto fan_total = static_cast<double>(fan.in + fan.out);
  const auto stddev = gain * std::sqrt(2.0 / fan_total);
  // Convert the target standard deviation to the half-width of the interval.
  const auto limit = std::sqrt(3.0) * stddev;
  tensor.uniform_(-limit, limit);
  return tensor;
}
// Zeroes `tensor` in place (gradient recording disabled) and returns it.
Tensor zeros_(Tensor tensor) {
  NoGradGuard no_grad;
  tensor.zero_();
  return tensor;
}
std::tuple<int64_t, int64_t> _calculate_fan_in_and_fan_out(
const Tensor& tensor) {
const auto dimensions = tensor.dim();
TORCH_CHECK(
dimensions >= 2,
"Fan in and fan out can not be computed "
"for tensor with fewer than 2 dimensions")
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t fan_in, fan_out;
if (dimensions == 2) { // Linear
fan_in = tensor.size(1);
fan_out = tensor.size(0);
} else {
const auto num_input_fmaps = tensor.size(1);
const auto num_output_fmaps = tensor.size(0);
auto receptive_field_size = 1;
if (tensor.dim() > 2) {
receptive_field_size = tensor[0][0].numel();
}
fan_in = num_input_fmaps * receptive_field_size;
fan_out = num_output_fmaps * receptive_field_size;
}
return std::tie(fan_in, fan_out);
}
} // namespace init
} // namespace nn
} // namespace torch