Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Added out= variant for torch.linalg.lstsq (#54721)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54721

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D27874711

Pulled By: mruberry

fbshipit-source-id: 696ebb6eb0bad81988e9cb7a081388a3a5ab3e2c
This commit is contained in:
commit 3d878dee45 (parent 43c747859c), committed by Facebook GitHub Bot
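For orientation before the diff: torch.linalg.lstsq minimizes the Frobenius norm of AX - B and returns a four-field result (solution, residuals, rank, singular_values). A minimal sketch of the functional form this PR builds on; the pinverse cross-check mirrors the updated test below, and exactly which fields are populated depends on the chosen driver:

```python
import torch

# Overdetermined system: A is (m, n) with m >= n, B is (m, nrhs).
A = torch.randn(5, 3, dtype=torch.float64)
B = torch.randn(5, 2, dtype=torch.float64)

# Returns a named tuple: (solution, residuals, rank, singular_values).
res = torch.linalg.lstsq(A, B)

# The least-squares solution agrees with the pseudoinverse formula used in test_linalg.py.
assert torch.allclose(res.solution, A.pinverse() @ B, atol=1e-5, rtol=1e-5)
```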
@@ -918,8 +918,7 @@ static Tensor& linalg_solve_out_info(Tensor& result, Tensor& infos, const Tensor
// - 2-dimensional (2D) tensor or batch of 2D tensors (matrix case)
// original torch.solve supported only the matrix case, while NumPy works for both cases
// for the batched input we need to be able to distinguish them
-auto expected_batched_rhs_shape = IntArrayRef(input.sizes().data(), input.dim()-1); // input.shape[:-1]
-bool vector_case = other.dim() == 1 || (input.dim()-1 == other.dim() && other.sizes().equals(expected_batched_rhs_shape));
+bool vector_case = linalg_solve_is_vector_rhs(input, other);

bool is_batched_column_major = false;
if (vector_case) {
@@ -929,7 +928,7 @@ static Tensor& linalg_solve_out_info(Tensor& result, Tensor& infos, const Tensor
}

// if 'other' is a batch of 2D tensors, then 'input' can be non-batched and will be broadcasted
-auto expected_shape = expected_batched_rhs_shape;
+auto expected_shape = IntArrayRef(input.sizes().data(), input.dim() - 1); // input.shape[:-1]
if (!vector_case && other.dim() > 2) {
expected_shape = other.sizes();
}
@@ -1020,8 +1019,7 @@ Tensor& linalg_solve_out(const Tensor& input, const Tensor& other, Tensor& resul

// Now check LAPACK/MAGMA error codes
// batchCheckErrors(Tensor, char*) calls 'infos = infos.to(kCPU)'
-auto expected_batched_rhs_shape = IntArrayRef(input.sizes().data(), input.dim()-1); // input.shape[:-1]
-bool vector_case = other.dim() == 1 || (input.dim()-1 == other.dim() && other.sizes().equals(expected_batched_rhs_shape));
+bool vector_case = linalg_solve_is_vector_rhs(input, other);
if (vector_case ? result.dim() > 1 : result.dim() > 2) {
batchCheckErrors(infos, "linalg_solve");
} else {
@@ -2971,204 +2969,378 @@ Tensor& _lstsq_helper_cpu(
#endif
}

std::tuple<Tensor, Tensor, Tensor, Tensor> linalg_lstsq(
const Tensor& self, const Tensor& b,
c10::optional<double> cond,
c10::optional<std::string> driver) {
TORCH_CHECK(
self.device().type() == b.device().type(),
"torch.linalg.lstsq: input tensors should be on the same device"
);
TORCH_CHECK(
self.scalar_type() == b.scalar_type(),
"torch.linalg.lstsq: input tensors should be of the same dtype"
);
TORCH_CHECK(
self.dim() >= 2,
"torch.linalg.lstsq: input `self` Tensor should be at least 2D"
);
TORCH_CHECK(
b.dim() >= 1,
"torch.linalg.lstsq: input `b` Tensor should be at least 1D"
);
auto dim_diff = self.dim() - b.dim();
TORCH_CHECK(
0 <= dim_diff && dim_diff <= 1,
"torch.linalg.lstsq: self.dim() must be greater or equal to b.dim() and "
"(self.dim() - b.dim()) <= 1"
);
Tensor b_2d = dim_diff ? b.unsqueeze(-1) : b;
TORCH_CHECK(
self.size(-2) == b_2d.size(-2),
dim_diff ? "torch.linalg.lstsq: self.size(-2) should match b.size(-1)" :
"torch.linalg.lstsq: self.size(-2) should match b.size(-2)"
);
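The dimension checks above (reworded but preserved in linalg_lstsq_out further down) can be restated compactly; a sketch in Python of the accepted shape combinations, assuming the same rules as these TORCH_CHECKs (the helper name is illustrative):

```python
def lstsq_shapes_ok(a_shape, b_shape):
    # a must be at least 2-D, b at least 1-D, and a may have at most one extra dimension.
    if len(a_shape) < 2 or len(b_shape) < 1:
        return False
    dim_diff = len(a_shape) - len(b_shape)
    if not (0 <= dim_diff <= 1):
        return False
    # When b has one dimension less, it is treated as a (batched) vector: unsqueeze(-1).
    b_2d = tuple(b_shape) + (1,) if dim_diff else tuple(b_shape)
    # The "row" dimensions must match: a.size(-2) == b_2d.size(-2).
    return a_shape[-2] == b_2d[-2]

assert lstsq_shapes_ok((5, 3), (5,))          # vector right-hand side
assert lstsq_shapes_ok((2, 5, 3), (2, 5, 4))  # batched matrix right-hand side
assert not lstsq_shapes_ok((5, 3), (4, 2))    # row dimensions disagree
```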
/*
Solves a least squares problem. That is minimizing the squared Frobenius norm of |B - A X|.

// if `driver` is empty, we use `driver_opt` to be set to
// c10::nullopt if working with CUDA tensors,
Input args:
* 'input' - Tensor containing batches of m-by-n matrix A.
* 'other' - Tensor containing batches of max(m, n)-by-nrhs matrix B.
* 'cond' - relative tolerance for determining rank of A.
* 'driver' - the name of the LAPACK driver that is used to compute the solution.
Output args (modified in-place):
* 'solution' - Tensor to store the solution matrix X.
* 'residuals' - Tensor to store values of the residual sum of squares for each column of the solution.
* 'rank' - Tensor to store the rank of A.
* 'singular_values' - Tensor to store the singular values of A.
* 'infos' - Tensor to store error codes of linear algebra math library.

For further details, please see the LAPACK documentation for GELS/GELSY/GELSS/GELSD routines.
*/
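Which of the output arguments listed above actually get filled depends on the driver: 'gelsd'/'gelss' are SVD-based and produce singular values, 'gelsy' reports the rank but no singular values, and 'gels' reports neither. A hedged CPU illustration, using the driver names from this diff:

```python
import torch

A = torch.randn(6, 4, dtype=torch.float64)
B = torch.randn(6, 2, dtype=torch.float64)

svd_based = torch.linalg.lstsq(A, B, driver="gelsd")  # rank and singular values
qr_based = torch.linalg.lstsq(A, B, driver="gelsy")   # rank only

assert svd_based.singular_values.numel() == min(A.shape)  # filled by gelsd/gelss
assert qr_based.singular_values.numel() == 0              # empty tensor, not None
assert qr_based.rank.dtype == torch.int64                 # promoted from LAPACK's int32
```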
static void linalg_lstsq_out_info(
Tensor& solution,
Tensor& residuals,
Tensor& rank,
Tensor& singular_values,
Tensor& infos,
const Tensor& input,
const Tensor& other,
double rcond,
std::string& driver) {
// These internal asserts make explicit the assumptions in the implementation
// Error check with the actual error messages are done on the higher level of
// the hierarchy of calls
TORCH_INTERNAL_ASSERT(input.dim() >= 2);
TORCH_INTERNAL_ASSERT(other.dim() >= 1);

auto dim_diff = input.dim() - other.dim();
TORCH_INTERNAL_ASSERT(0 <= dim_diff && dim_diff <= 1);

TORCH_INTERNAL_ASSERT(input.scalar_type() == other.scalar_type());
TORCH_INTERNAL_ASSERT(input.device() == other.device());

TORCH_INTERNAL_ASSERT(solution.scalar_type() == input.scalar_type());
TORCH_INTERNAL_ASSERT(solution.device() == input.device());

TORCH_INTERNAL_ASSERT(residuals.device() == input.device());

TORCH_INTERNAL_ASSERT(rank.scalar_type() == at::kLong);
TORCH_INTERNAL_ASSERT(rank.device() == input.device());

auto real_dtype = toValueType(input.scalar_type());
TORCH_INTERNAL_ASSERT(singular_values.scalar_type() == real_dtype);
TORCH_INTERNAL_ASSERT(singular_values.device() == input.device());

TORCH_INTERNAL_ASSERT(infos.scalar_type() == at::kInt);
TORCH_INTERNAL_ASSERT(infos.device() == input.device());
TORCH_INTERNAL_ASSERT(infos.numel() == std::max<int64_t>(1, batchCount(input)));
TORCH_INTERNAL_ASSERT(infos.is_contiguous());

bool vector_case = linalg_solve_is_vector_rhs(input, other);
// we need to unsqueeze 'other' because 2-dimensional tensors are expected in the implementation
Tensor other_2d = vector_case ? other.unsqueeze(-1) : other;

TORCH_INTERNAL_ASSERT(input.size(-2) == other_2d.size(-2));

std::vector<int64_t> expected_solution_shape = broadcast_batch_size(input, other_2d, input.dim() - 2);
// the actual shape of the solution returned is (*, n,) or (*, n, nrhs)
// but LAPACK requires extra dimensions to store raw residuals
// so the expected shape is (*, max(m, n),) or (*, max(m, n), nrhs)
auto m = input.size(-2);
auto n = input.size(-1);
auto nrhs = other.size(-1);
expected_solution_shape.push_back(std::max(m, n));
if (!vector_case) {
expected_solution_shape.push_back(nrhs);
}

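A sketch of the shape bookkeeping above: the workspace handed to LAPACK has max(m, n) rows (so the extra rows can hold residual data) and is prefixed by the broadcasted batch shape. The helper name below is mine, not ATen's:

```python
import torch

def expected_lstsq_workspace_shape(input_shape, other_2d_shape, vector_case):
    # Broadcast the batch dimensions of 'input' and the already-unsqueezed 'other'.
    batch = torch.broadcast_shapes(input_shape[:-2], other_2d_shape[:-2])
    m, n = input_shape[-2], input_shape[-1]
    nrhs = other_2d_shape[-1]
    shape = list(batch) + [max(m, n)]
    if not vector_case:
        shape.append(nrhs)
    return tuple(shape)

# Matrix RHS: A is (2, 5, 3), B is (2, 5, 4) -> workspace (2, 5, 4).
assert expected_lstsq_workspace_shape((2, 5, 3), (2, 5, 4), vector_case=False) == (2, 5, 4)
# Vector RHS: A is (5, 3), B is (5,) viewed as (5, 1) -> workspace (5,).
assert expected_lstsq_workspace_shape((5, 3), (5, 1), vector_case=True) == (5,)
```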
// if 'solution' has no elements we can modify it
if (solution.numel() == 0) {
if (vector_case) {
solution.resize_(expected_solution_shape, MemoryFormat::Contiguous);
} else {
auto shape_transposed = expected_solution_shape;
std::swap(shape_transposed.end()[-1], shape_transposed.end()[-2]);
solution.resize_(shape_transposed, MemoryFormat::Contiguous);
solution.transpose_(-2, -1);
}
}

// if 'solution' is non-empty it must have the expected shape
TORCH_INTERNAL_ASSERT(solution.sizes().equals(expected_solution_shape));

// 'solution' must be in batched column major order (Fortran contiguous) for 2D inputs
// or C contiguous for 1D input
if (vector_case) {
TORCH_INTERNAL_ASSERT(solution.is_contiguous());
} else {
TORCH_INTERNAL_ASSERT(solution.transpose(-2, -1).is_contiguous());
}

// for 1-dimensional 'other', we need to unsqueeze the 'solution' before passing to "apply_solve"
if (vector_case) {
solution = solution.unsqueeze_(-1);
}

// _linalg_lstsq_helper_ performs calculations in-place and 'solution' must be a copy of other_2d
solution.narrow(-2, 0, other_2d.size(-2)).copy_(other_2d);

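"Batched column-major order" above means each matrix in the batch is Fortran-contiguous, which is what copyBatchedColumnMajor produces on the C++ side. A rough Python equivalent of the layout and its check (function names are illustrative):

```python
import torch

def to_batched_column_major(t):
    # Transpose, make contiguous, transpose back: column-major data, original shape.
    return t.transpose(-2, -1).contiguous().transpose(-2, -1)

def is_batched_column_major(t):
    return t.transpose(-2, -1).is_contiguous()

x = torch.randn(2, 5, 3)
assert not is_batched_column_major(x)  # freshly allocated tensors are row-major
assert is_batched_column_major(to_batched_column_major(x))
```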
// if 'rank' is empty we might resize it
auto input_batch_shape = IntArrayRef(input.sizes().cbegin(), input.sizes().cend() - 2);
if (rank.numel() == 0 && driver != "gels") { // gels driver doesn't set 'rank'
rank.resize_(input_batch_shape, MemoryFormat::Contiguous);
}

// if 'rank' is non-empty it must have the expected shape and be contiguous
if (driver != "gels") {
TORCH_INTERNAL_ASSERT(rank.sizes().equals(input_batch_shape));
TORCH_INTERNAL_ASSERT(rank.is_contiguous());
}

// if 'singular_values' is empty we might resize it
auto singular_values_shape = input_batch_shape.vec();
singular_values_shape.push_back(std::min(m, n));
if (singular_values.numel() == 0 && (driver == "gelsd" || driver == "gelss")) {
singular_values.resize_(singular_values_shape, MemoryFormat::Contiguous);
}

// if 'singular_values' is non-empty it must have the expected shape and be contiguous
if (driver == "gelsd" || driver == "gelss") {
TORCH_INTERNAL_ASSERT(singular_values.sizes().equals(singular_values_shape));
TORCH_INTERNAL_ASSERT(singular_values.is_contiguous());
}

// 'input' is modified in-place so we need a column-major copy
auto input_working_copy = copyBatchedColumnMajor(input);

// now the actual call that computes the result in-place (apply_lstsq)
at::_lstsq_helper_(solution, rank, singular_values, infos, input_working_copy, rcond, driver);

if (m > n && driver != "gelsy") {
// LAPACK stores residuals data for postprocessing in rows n:(m-n)
auto raw_residuals = solution.narrow(/*dim=*/-2, /*start=*/n, /*length*/m - n);
if (raw_residuals.is_complex()) {
raw_residuals.mul_(raw_residuals.conj());
raw_residuals = at::real(raw_residuals);
} else {
raw_residuals.pow_(2);
}
at::sum_out(residuals, raw_residuals, /*dim=*/-2, /*keepdim=*/false, /*dtype*/real_dtype);
}
solution = solution.narrow(/*dim=*/-2, /*start=*/0, /*length*/n);
if (m == 0) {
solution.zero_();
}

// for 1-dimensional 'other', we need to squeeze the solution after "apply_lstsq"
if (vector_case) {
solution = solution.squeeze_(-1);
}
}

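The residual handling above relies on LAPACK leaving the residual components in rows n..max(m, n) of the workspace when m > n; summing their squared magnitudes gives the residual sum of squares per column. The same quantity can be computed from first principles, which makes a convenient cross-check (a sketch; 'gelsd' chosen because it populates residuals):

```python
import torch

A = torch.randn(6, 3, dtype=torch.float64)  # m > n, so residuals are defined
B = torch.randn(6, 2, dtype=torch.float64)

res = torch.linalg.lstsq(A, B, driver="gelsd")

# Residual sum of squares per column of the solution, computed directly.
manual = ((A @ res.solution - B) ** 2).sum(dim=-2)
assert torch.allclose(res.residuals, manual, rtol=1e-6, atol=1e-10)
```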
static std::string get_default_lstsq_driver(c10::optional<std::string> driver, const Tensor& input) {
// if `driver` is empty, we set driver_str to "gels" if working with CUDA tensors,
// otherwise to "gelsy" driver.
// CUDA tensors are treated specially because MAGMA
// has only 'gels' driver supported.
c10::optional<std::string> driver_opt = driver;
std::string driver_str;
// check whether the user provided name is a valid driver name
if (driver.has_value()) {
auto driver_str = driver.value();
driver_str = driver.value();
// convert `driver_str` to lower case inplace.
std::transform(driver_str.begin(), driver_str.end(), driver_str.begin(),
[](unsigned char c) { return std::tolower(c); });
static std::unordered_set<std::string> allowed_drivers = {
"gels", "gelsy", "gelsd", "gelss"
};
if (at::kCPU == self.device().type()) {
if (input.device() == at::kCPU) {
TORCH_CHECK(
allowed_drivers.find(driver_str) != allowed_drivers.end(),
"torch.linalg.lstsq: parameter `driver` should be one of "
"(gels, gelsy, gelsd, gelss)"
);
}
//else if (at::kCUDA == self.device().type()) {
else {
} else { // else if (input.is_cuda())
TORCH_CHECK(
driver_str == "gels",
"torch.linalg.lstsq: `driver` other than `gels` is not supported on CUDA"
);
}
} else {
// if driver name is not provided, set to default 'gelsy' if on CPU,
// or to `gels` if on CUDA.
driver_str = input.is_cuda() ? "gels" : "gelsy";
}
// if driver name is not provided, set to default 'gelsy' if on CPU,
// or to `gels` if on CUDA.
else {
driver_opt = (at::kCPU == self.device().type())
? c10::optional<std::string>("gelsy")
: c10::optional<std::string>("gels");
return driver_str;
}

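A Python restatement of the driver resolution added above: validate a user-supplied name against the four LAPACK drivers on CPU, accept only 'gels' on CUDA (MAGMA provides nothing else), and fall back to 'gelsy' on CPU or 'gels' on CUDA when no driver is given. Sketch only; the function name is mine:

```python
def resolve_lstsq_driver(driver, is_cuda):
    allowed = {"gels", "gelsy", "gelsd", "gelss"}
    if driver is None:
        # MAGMA only implements 'gels', hence the CUDA default.
        return "gels" if is_cuda else "gelsy"
    driver = driver.lower()
    if is_cuda and driver != "gels":
        raise RuntimeError("torch.linalg.lstsq: `driver` other than `gels` is not supported on CUDA")
    if not is_cuda and driver not in allowed:
        raise RuntimeError("torch.linalg.lstsq: parameter `driver` should be one of (gels, gelsy, gelsd, gelss)")
    return driver

assert resolve_lstsq_driver(None, is_cuda=False) == "gelsy"
assert resolve_lstsq_driver("GELSD", is_cuda=False) == "gelsd"
assert resolve_lstsq_driver(None, is_cuda=True) == "gels"
```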
std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> linalg_lstsq_out(
const Tensor& input,
const Tensor& other,
c10::optional<double> rcond,
c10::optional<std::string> driver,
Tensor& solution,
Tensor& residuals,
Tensor& rank,
Tensor& singular_values) {
TORCH_CHECK(input.dim() >= 2, "torch.linalg.lstsq: input must have at least 2 dimensions.");
TORCH_CHECK(other.dim() >= 1, "torch.linalg.lstsq: other must have at least 1 dimension.");
TORCH_CHECK(
input.scalar_type() == other.scalar_type(),
"torch.linalg.lstsq: Expected input and other to have the same dtype, but got input's dtype ",
input.scalar_type(),
" and other's dtype ",
other.scalar_type());

auto dim_diff = input.dim() - other.dim();
TORCH_CHECK(
0 <= dim_diff && dim_diff <= 1,
"torch.linalg.lstsq: input.dim() must be greater or equal to other.dim() and (input.dim() - other.dim()) <= 1");
Tensor other_2d = dim_diff ? other.unsqueeze(-1) : other;
TORCH_CHECK(
input.size(-2) == other_2d.size(-2),
dim_diff ? "torch.linalg.lstsq: input.size(-2) should match other.size(-1)"
: "torch.linalg.lstsq: input.size(-2) should match other.size(-2)");

checkSameDevice("torch.linalg.lstsq", other, input, "other");
checkSameDevice("torch.linalg.lstsq", solution, input, "solution");
checkSameDevice("torch.linalg.lstsq", residuals, input, "residuals");
checkSameDevice("torch.linalg.lstsq", rank, input, "rank");
checkSameDevice("torch.linalg.lstsq", singular_values, input, "singular_values");

// 'solution' is expected to have same dtype as input
checkLinalgCompatibleDtype("torch.linalg.lstsq", solution, input, "solution");

// 'residuals' is expected to have real float dtype
ScalarType real_dtype = c10::toValueType(input.scalar_type());
checkLinalgCompatibleDtype("torch.linalg.lstsq", residuals.scalar_type(), real_dtype, "solution");

// 'rank' is expected to have integer dtype
// actual LAPACK calls use int32_t type for rank, but we promote it to int64_t
// to be consistent with torch.linalg.matrix_rank output dtype
ScalarType rank_expected_type = ScalarType::Long;
checkLinalgCompatibleDtype("torch.linalg.lstsq", rank.scalar_type(), rank_expected_type, "rank");

// 'singular_values' is expected to have real float dtype
checkLinalgCompatibleDtype("torch.linalg.lstsq", singular_values.scalar_type(), real_dtype, "singular_values");

std::string driver_name = get_default_lstsq_driver(driver, input);

// set default rcond value
// TODO: Change this to match non-legacy NumPy behaviour
double rcond_value = rcond.has_value() && (rcond.value() > 0)
? rcond.value()
: _get_epsilon(c10::toValueType(input.scalar_type()));

auto infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, input.options().dtype(kInt));

// now check whether the provided output tensors can be used directly

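The default rcond above is the machine epsilon of the real dtype of the input (what _get_epsilon returns); the TODO refers to non-legacy NumPy, which scales that epsilon by max(m, n). A sketch of both choices:

```python
import torch

def default_rcond(dtype, m, n, legacy=True):
    # Real counterpart of a complex dtype, mirroring toValueType on the C++ side.
    real_dtype = {torch.complex64: torch.float32, torch.complex128: torch.float64}.get(dtype, dtype)
    eps = torch.finfo(real_dtype).eps
    # Current behaviour: plain machine epsilon; non-legacy NumPy: eps * max(m, n).
    return eps if legacy else eps * max(m, n)

print(default_rcond(torch.float32, 6, 4))                   # ~1.19e-07
print(default_rcond(torch.complex128, 6, 4, legacy=False))  # ~1.33e-15
```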
// Two types of 'other' tensors are supported:
// - 1-dimensional (1D) tensor or batch of 1D tensors (vector case)
// - 2-dimensional (2D) tensor or batch of 2D tensors (matrix case)
// original torch.lstsq supported only the matrix case, while NumPy works for both cases
// for the batched input we need to be able to distinguish them
// auto expected_batched_rhs_shape = IntArrayRef(input.sizes().data(), input.dim() - 1); // input.shape[:-1]
// bool vector_case = other.dim() == 1 || (input.dim() - 1 == other.dim() && other.sizes().equals(expected_batched_rhs_shape));
bool vector_case = linalg_solve_is_vector_rhs(input, other);

// provided output tensor can be used directly if:
// 1. the shape matches the expected shape
// 2. the dtype matches the expected dtype
// 3. the tensor is contiguous

// Checks for the 'solution' tensor
std::vector<int64_t> expected_solution_shape = broadcast_batch_size(input, other_2d, input.dim() - 2);
// the actual shape of the shape of the solution returned in (*, n,) or (*, n, nrhs)
// but LAPACK requires extra dimensions so the expected shape is (*, max(m, n),) or (*, max(m, n), nrhs)
expected_solution_shape.push_back(std::max(input.size(-1), input.size(-2)));
if (!vector_case && other.dim() > 2) {
expected_solution_shape.push_back(other.size(-1));
}

// CUDA has only `gels` driver now which ONLY works with overdetermined systems
if (at::kCUDA == self.device().type()) {
TORCH_CHECK(
self.size(-2) >= self.size(-1),
"torch.linalg.lstsq: only overdetermined systems (m >= n) are allowed on CUDA"
);
bool solution_equal_expected_shape = solution.sizes().equals(expected_solution_shape);
bool solution_input_same_type = (solution.scalar_type() == input.scalar_type());

bool is_solution_batched_column_major = false;
if (vector_case) {
is_solution_batched_column_major = solution.is_contiguous();
} else if (!vector_case && solution.dim() >= 2) {
is_solution_batched_column_major = solution.transpose(-2, -1).is_contiguous();
}

// LAPACK/MAGMA requries inputs to be in the column-major-order.
auto self_working_copy = copyBatchedColumnMajor(self);
// 'residuals' is not checked here because at::sum_out(residuals, ...) does that

// Tensor b must be of size (..., max(m, n), nrhs)
// and in the column-major order.
// We allow the batch dims of `self` to broadcast over the batch
// dims of `b` so that it is possible to solve multiple systems with
// with the same lhs (encoded by `self`) / rhs (encoded by `b`).
// `b_working_copy` is modified in-place and the combination of
// batch broadcasting plus LAPACK/MAGMA requirements impose the following
// restrictions on sizes/strides of `b`:
// 1. b.size = (broadcasted_batch_size(self, b), max(m, n), nrhs).
// 2. b.stride should correspond to an almost contiguous Tensor in the column-major-order,
// i.e. b.stride = b.transpose(-2, -1).contiguous().transpose(-2, -1).strides()
auto m = self.size(-2);
auto n = self.size(-1);
auto b_working_copy = copyBatchedColumnMajor(b_2d,
/*nrows=*/std::max(m, n),
/*desired_batch_sizes=*/broadcast_batch_size(self, b_2d, self.dim() - 2));
auto input_batch_shape = IntArrayRef(input.sizes().cbegin(), input.sizes().cend() - 2);

double rcond = cond.has_value() && (cond.value() > 0)
? cond.value()
: _get_epsilon(c10::toValueType(self.scalar_type()));

auto batch_shape = IntArrayRef(self.sizes().cbegin(), self.sizes().cend() - 2);
Tensor rank = at::empty({0}, self.options().dtype(at::kLong));
if (driver_opt.value() != "gels") {
rank.resize_(batch_shape, MemoryFormat::Contiguous);
// Checks for the 'rank' tensor
// rank is a scalar value for each matrix in the batch so
// rank's expected shape is equal to input.shape[0:input.ndim-2]
bool rank_equal_expected_shape = true;
bool rank_equal_expected_type = true;
bool rank_is_contiguous = true;
if (driver_name != "gels") { // gels driver doesn't set 'rank'
rank_equal_expected_shape = rank.sizes().equals(input_batch_shape);
rank_equal_expected_type = (rank.scalar_type() == at::kLong);
rank_is_contiguous = rank.is_contiguous();
}

auto singular_values_shape = batch_shape.vec();
singular_values_shape.push_back(std::min(m, n));
auto real_dtype = c10::toValueType(self.scalar_type());
Tensor singular_values = at::empty({0}, self.options().dtype(real_dtype));
if (driver_opt.value() == "gelsd" || driver_opt.value() == "gelss") {
singular_values.resize_(singular_values_shape, MemoryFormat::Contiguous);
// Checks for the 'singular_values' tensor
// singular values are computed only with "gelsd" and "gelss" drivers currently
bool singular_values_equal_expected_shape = true;
bool singular_values_equal_expected_type = true;
bool singular_values_is_contiguous = true;
if (driver_name == "gelsd" || driver_name == "gelss") {
auto singular_values_shape = input_batch_shape.vec();
singular_values_shape.push_back(std::min(input.size(-1), input.size(-2)));
singular_values_equal_expected_shape = singular_values.sizes().equals(singular_values_shape);
singular_values_equal_expected_type = (singular_values.scalar_type() == real_dtype);
singular_values_is_contiguous = singular_values.is_contiguous();
}

Tensor infos = at::zeros({std::max<int64_t>(1, batchCount(self))}, self.options().dtype(kInt).device(kCPU));
// if solution is not empty and not in batched column major format
bool copy_needed = (solution.numel() != 0 && !is_solution_batched_column_major);
copy_needed |= !solution_input_same_type; // or solution does not have the same dtype as input
copy_needed |= (solution.numel() != 0 && !solution_equal_expected_shape); // or solution does not have the expected shape

Tensor x, residuals;
copy_needed |= !rank_equal_expected_type;
copy_needed |= (rank.numel() != 0 && !rank_equal_expected_shape);
copy_needed |= (rank.numel() != 0 && !rank_is_contiguous);

// path if neither `self` nor `b` is empty
if (self.numel() && b.numel()) {
x = at::_lstsq_helper_(b_working_copy, rank, singular_values, infos, self_working_copy, rcond, driver_opt.value());
if (m > n && driver_opt.value() != "gelsy") {
residuals = x.narrow(-2, n, std::max(m, n) - n).abs().pow_(2).sum(-2);
}
x = x.narrow(-2, 0, n);
}
// if either `self` or `b` is empty, return an empty tensor or,
// if non-zero sizes, return a tensor of zeros.
else {
x = b_working_copy.zero_().narrow(-2, 0, n);
copy_needed |= !singular_values_equal_expected_type;
copy_needed |= (singular_values.numel() != 0 && !singular_values_equal_expected_shape);
copy_needed |= (singular_values.numel() != 0 && !singular_values_is_contiguous);

if (copy_needed) { // we have to allocate temporary tensors
Tensor solution_tmp = at::empty({0}, input.options());
Tensor residuals_tmp = at::empty({0}, input.options().dtype(real_dtype));
Tensor rank_tmp = at::empty({0}, input.options().dtype(at::kLong));
Tensor singular_values_tmp = at::empty({0}, input.options().dtype(real_dtype));

linalg_lstsq_out_info(solution_tmp, residuals_tmp, rank_tmp, singular_values_tmp, infos, input, other, rcond_value, driver_name);

at::native::resize_output(solution, solution_tmp.sizes());
solution.copy_(solution_tmp);

at::native::resize_output(residuals, residuals_tmp.sizes());
residuals.copy_(residuals_tmp);

at::native::resize_output(rank, rank_tmp.sizes());
rank.copy_(rank_tmp);

at::native::resize_output(singular_values, singular_values_tmp.sizes());
singular_values.copy_(singular_values_tmp);
} else {
// else use the provided output storage directly
linalg_lstsq_out_info(solution, residuals, rank, singular_values, infos, input, other, rcond_value, driver_name);
}

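The copy_needed logic above is what makes the new out= overload forgiving: when a supplied output has the wrong layout, dtype, or shape, the computation runs into temporaries and the results are copied back, with resize_output growing the user's tensor as needed. A sketch of the per-tensor decision, with illustrative names:

```python
import torch

def out_tensor_needs_copy(out, expected_shape, expected_dtype, batched_column_major):
    # An empty out tensor can simply be resized in place, so only its dtype matters.
    if out.numel() == 0:
        return out.dtype != expected_dtype
    if out.dtype != expected_dtype or tuple(out.shape) != tuple(expected_shape):
        return True
    # Otherwise the memory layout decides: Fortran-contiguous per matrix, or plain contiguous.
    if batched_column_major:
        return not out.transpose(-2, -1).is_contiguous()
    return not out.is_contiguous()

solution_out = torch.empty(0, dtype=torch.float32)  # wrong dtype for a float64 problem
assert out_tensor_needs_copy(solution_out, (5, 3), torch.float64, batched_column_major=True)
```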
auto return_empty_if_undefined = [&self](Tensor& t,
c10::optional<at::ScalarType> dtype = c10::nullopt,
c10::optional<std::vector<int64_t>> shape = c10::nullopt) {
if (t.defined()) {
return t;
}
else {
auto output_dtype = dtype.has_value() ? dtype.value() : self.scalar_type();
if (shape.has_value()) {
return at::empty(shape.value(), self.options().dtype(output_dtype));
}
else {
return at::empty({0}, self.options().dtype(output_dtype));
}
}
};

// Some output stays undefined for some values of driver.
// Instead of returning undefined tensors which get exposed as
// Nones in the Python interface, we return empty tensors.
// This way we follow the convention of output types in the
// torch.linalg namespace.
// NOTE: we run drivers only if both inputs are non-empty!
// Hence the code below explicitly handles each and every output
// if `self` is empty.

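The comments above spell out the convention that both the functional form and the new out variant follow: outputs a driver does not produce come back as empty tensors rather than None. For example, on CPU with the non-rank-revealing 'gels' driver:

```python
import torch

A = torch.randn(4, 4, dtype=torch.float64)
B = torch.randn(4, 2, dtype=torch.float64)

res = torch.linalg.lstsq(A, B, driver="gels")

# 'gels' computes neither the rank nor singular values, so both are empty tensors, not None.
assert res.rank.numel() == 0 and res.singular_values.numel() == 0
# Square system (m == n): there are no residual rows either.
assert res.residuals.numel() == 0
```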
// Numpy and Scipy always return ranks for empty matrices,
// even for drivers which are not rank-revealing.
if (self.numel()) {
rank = return_empty_if_undefined(rank, at::kLong);
}
else {
rank = at::zeros(batch_shape, self.options().dtype(at::kLong));
}

// undefined residuals could only be an empty Tensor of shape (0)
residuals = return_empty_if_undefined(residuals);

if (!self.numel()
&& (driver_opt.value() == "gelss" || driver_opt.value() == "gelsd")) {
// when `self` is empty, return singular_values of shape
// (*self.shape[:-2], 0) only if driver is in ('gelss', 'gelsd')
auto singular_values_empty_shape = batch_shape.vec();
singular_values_empty_shape.push_back(0);
singular_values = return_empty_if_undefined(
singular_values,
at::toValueType(self.scalar_type()),
singular_values_empty_shape);
}
else {
// otherwise return an empty tensor of shape (0)
singular_values = return_empty_if_undefined(
singular_values,
at::toValueType(self.scalar_type()));
}

if (self.dim() > 2) {
if (infos.numel() > 1) {
batchCheckErrors(infos, "torch.linalg.lstsq");
} else {
singleCheckErrors(infos.item().toInt(), "torch.linalg.lstsq");
singleCheckErrors(infos.item<int64_t>(), "torch.linalg.lstsq");
}

return std::make_tuple(x, residuals, rank, singular_values);
return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(solution, residuals, rank, singular_values);
}

std::tuple<Tensor, Tensor, Tensor, Tensor> linalg_lstsq(
const Tensor& input, const Tensor& other,
c10::optional<double> rcond,
c10::optional<std::string> driver) {
Tensor solution = at::empty({0}, input.options());
Tensor residuals = at::empty({0}, input.options().dtype(toValueType(input.scalar_type())));
Tensor rank = at::empty({0}, input.options().dtype(at::kLong));
Tensor singular_values = at::empty({0}, input.options().dtype(toValueType(input.scalar_type())));
std::tie(solution, residuals, rank, singular_values) =
at::linalg_lstsq_outf(input, other, rcond, driver, solution, residuals, rank, singular_values);
return std::make_tuple(solution, residuals, rank, singular_values);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -514,4 +514,20 @@ static inline void checkLinalgCompatibleDtype(const std::string& fn_name, Scalar
out_name, " with dtype ", out_type);
}

+/*
+Two types of 'other' tensors are supported when solving
+a system of linear equations matmul(input, x) = other:
+* 1-dimensional (1D) tensor or batch of 1D tensors (vector case)
+* 2-dimensional (2D) tensor or batch of 2D tensors (matrix case).
+The original torch.solve supported only the matrix case, while NumPy works for both cases.
+For the batched input we need to be able to distinguish them.
+Let input.shape = (batch_dimensions, m, n), then 'other' is of vector type if other.shape == (batch_dimensions, m).
+This rule is compatible with NumPy, see https://github.com/numpy/numpy/blob/v1.20.0/numpy/linalg/linalg.py#L384-L389
+*/
+static inline bool linalg_solve_is_vector_rhs(const Tensor& input, const Tensor& other) {
+auto expected_batched_rhs_shape = IntArrayRef(input.sizes().data(), input.dim() - 1); // input.shape[:-1]
+bool vector_case = other.dim() == 1 || (input.dim() - 1 == other.dim() && other.sizes().equals(expected_batched_rhs_shape));
+return vector_case;
+}
+
}} // namespace at::native
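The helper added above encodes the NumPy-compatible rule for telling a batched vector right-hand side from a batched matrix one. Restated in Python (same rule, illustrative function name):

```python
def is_vector_rhs(input_shape, other_shape):
    # 'other' is a (batch of) vector(s) iff it is 1-D, or it has exactly one dimension
    # less than 'input' and matches input.shape[:-1].
    if len(other_shape) == 1:
        return True
    return (len(input_shape) - 1 == len(other_shape)
            and tuple(other_shape) == tuple(input_shape[:-1]))

assert is_vector_rhs((5, 3), (5,))              # single vector rhs
assert is_vector_rhs((2, 5, 3), (2, 5))         # batch of vector rhs
assert not is_vector_rhs((2, 5, 3), (2, 5, 4))  # batch of matrix rhs
```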
@@ -2667,6 +2667,11 @@ TORCH_CHECK(false, "torch.linalg.lstsq: MAGMA library not found in "
auto trans = MagmaNoTrans;
auto m = magma_int_cast(a.size(-2), "m");
auto n = magma_int_cast(a.size(-1), "n");
+
+TORCH_CHECK(
+m >= n,
+"torch.linalg.lstsq: only overdetermined systems (input.size(-2) >= input.size(-1)) are allowed on CUDA");
+
auto nrhs = magma_int_cast(b.size(-1), "nrhs");
auto ldda = std::max<magma_int_t>(1, m);
auto lddb = std::max<magma_int_t>(1, std::max(m, n));
@@ -8637,6 +8637,12 @@
dispatch:
CompositeExplicitAutograd: linalg_lstsq

+- func: linalg_lstsq.out(Tensor self, Tensor b, float? cond=None, *, str? driver=None, Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values) -> (Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values)
+python_module: linalg
+variants: function
+dispatch:
+CPU, CUDA: linalg_lstsq_out
+
- func: _lstsq_helper_(Tensor(a!) self, Tensor(b!) rank, Tensor(c!) singular_values, Tensor(d!) infos, Tensor a, float cond, str driver_name) -> Tensor(a!)
variants: function
dispatch:
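Per the linalg_lstsq.out schema registered above, the Python binding accepts a 4-tuple of output tensors in schema order. A hedged usage sketch; whether this exact out= form is exposed in a given build depends on the PyTorch version, since it is introduced by this very PR:

```python
import torch

A = torch.randn(6, 4, dtype=torch.float64)
B = torch.randn(6, 2, dtype=torch.float64)

solution = torch.empty(0, dtype=torch.float64)
residuals = torch.empty(0, dtype=torch.float64)
rank = torch.empty(0, dtype=torch.int64)
singular_values = torch.empty(0, dtype=torch.float64)

# out= expects (solution, residuals, rank, singular_values), matching the schema order.
torch.linalg.lstsq(A, B, driver="gelsd",
                   out=(solution, residuals, rank, singular_values))

assert solution.shape == (4, 2)  # empty outputs are resized via resize_output
```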
@@ -132,7 +132,7 @@ class TestLinalg(TestCase):
sol2 = a.pinverse() @ b
self.assertEqual(sol, sol2, atol=1e-5, rtol=1e-5)

-def check_correctness_ref(a, b, res, ref):
+def check_correctness_ref(a, b, res, ref, driver="default"):
def apply_if_not_empty(t, f):
if t.numel():
return f(t)
@@ -157,18 +157,34 @@ class TestLinalg(TestCase):
rank_1d = apply_if_not_empty(res.rank, lambda t: t.view(-1))
singular_values_2d = res.singular_values.view(batch_size, res.singular_values.shape[-1])

-for i in range(batch_size):
-sol, residuals, rank, singular_values = ref(
-a_3d.select(0, i).numpy(),
-b_3d.select(0, i).numpy()
-)
-# Singular values are None when lapack_driver='gelsy' in SciPy
-if singular_values is None:
-singular_values = []
-self.assertEqual(sol, solution_3d.select(0, i), atol=1e-5, rtol=1e-5)
-self.assertEqual(residuals, select_if_not_empty(residuals_2d, i), atol=1e-5, rtol=1e-5)
-self.assertEqual(rank, select_if_not_empty(rank_1d, i), atol=1e-5, rtol=1e-5)
-self.assertEqual(singular_values, singular_values_2d.select(0, i), atol=1e-5, rtol=1e-5)
+if a.numel() > 0:
+for i in range(batch_size):
+sol, residuals, rank, singular_values = ref(
+a_3d.select(0, i).numpy(),
+b_3d.select(0, i).numpy()
+)
+# Singular values are None when lapack_driver='gelsy' in SciPy
+if singular_values is None:
+singular_values = []
+self.assertEqual(sol, solution_3d.select(0, i), atol=1e-5, rtol=1e-5)
+self.assertEqual(residuals, select_if_not_empty(residuals_2d, i), atol=1e-5, rtol=1e-5)
+self.assertEqual(rank, select_if_not_empty(rank_1d, i), atol=1e-5, rtol=1e-5)
+self.assertEqual(singular_values, singular_values_2d.select(0, i), atol=1e-5, rtol=1e-5)
+else:
+self.assertEqual(res.solution.shape, (*a.shape[:-2], n, nrhs))
+self.assertEqual(res.rank.shape, a.shape[:-2])
+
+# residuals are not always computed (and have non-zero shape)
+if m > n and driver != "gelsy":
+self.assertEqual(res.residuals.shape, (*a.shape[:-2], 0))
+else:
+self.assertEqual(res.residuals.shape, (0, ))
+
+# singular_values are not always computed (and have non-zero shape)
+if driver == "default" or driver == "gelsd" or driver == "gelss":
+self.assertEqual(res.singular_values.shape, (*a.shape[:-2], min(m, n)))
+else:
+self.assertEqual(res.singular_values.shape, (0, ))

def check_correctness_scipy(a, b, res, driver, cond):
if TEST_SCIPY and driver not in (None, 'gels'):
@@ -176,7 +192,7 @@ class TestLinalg(TestCase):

def scipy_ref(a, b):
return scipy.linalg.lstsq(a, b, lapack_driver=driver, cond=cond)
-check_correctness_ref(a, b, res, scipy_ref)
+check_correctness_ref(a, b, res, scipy_ref, driver=driver)

def check_correctness_numpy(a, b, res, driver, cond):
if driver in ('gelsd', 'gelss'):
@@ -317,16 +333,16 @@ class TestLinalg(TestCase):
a = torch.rand(2, 3, dtype=dtype, device=device)
b = torch.rand(3, dtype=dtype, device=device)

-with self.assertRaisesRegex(RuntimeError, 'input `self` Tensor should be at least 2D'):
+with self.assertRaisesRegex(RuntimeError, 'input must have at least 2 dimensions'):
torch.linalg.lstsq(b, b)

-with self.assertRaisesRegex(RuntimeError, 'input `b` Tensor should be at least 1D'):
+with self.assertRaisesRegex(RuntimeError, 'other must have at least 1 dimension'):
torch.linalg.lstsq(a, torch.tensor(1, dtype=dtype, device=device))

-with self.assertRaisesRegex(RuntimeError, r'self.size\(-2\) should match b.size\(-1\)'):
+with self.assertRaisesRegex(RuntimeError, r'input.size\(-2\) should match other.size\(-1\)'):
torch.linalg.lstsq(a, b)

-with self.assertRaisesRegex(RuntimeError, r'self.size\(-2\) should match b.size\(-2\)'):
+with self.assertRaisesRegex(RuntimeError, r'input.size\(-2\) should match other.size\(-2\)'):
torch.linalg.lstsq(a, b.unsqueeze(-1))

def complement_device(device):
@@ -338,11 +354,11 @@ class TestLinalg(TestCase):
a = torch.rand(2, 2, 2, 2, dtype=dtype, device=device)
b = torch.rand(2, 2, 2, dtype=dtype, device=complement_device(device))
if a.device != b.device:
-with self.assertRaisesRegex(RuntimeError, 'input tensors should be on the same device'):
+with self.assertRaisesRegex(RuntimeError, 'be on the same device'):
torch.linalg.lstsq(a, b)

b = (torch.rand(2, 2, 2, dtype=dtype, device=device) * 100).long()
-with self.assertRaisesRegex(RuntimeError, 'input tensors should be of the same dtype'):
+with self.assertRaisesRegex(RuntimeError, 'the same dtype'):
torch.linalg.lstsq(a, b)

a = torch.rand(2, 2, 2, 2, dtype=dtype, device=device)
@@ -359,7 +375,7 @@ class TestLinalg(TestCase):
if device != 'cpu':
a = torch.rand(2, 3, dtype=dtype, device=device)
b = torch.rand(2, 1, dtype=dtype, device=device)
-with self.assertRaisesRegex(RuntimeError, r'only overdetermined systems \(m >= n\) are allowed on CUDA'):
+with self.assertRaisesRegex(RuntimeError, r'only overdetermined systems'):
torch.linalg.lstsq(a, b)

@skipCUDAIfNoMagma
@@ -3712,7 +3712,7 @@ op_db: List[OpInfo] = [
aten_name='linalg_lstsq',
op=torch.linalg.lstsq,
dtypes=floating_and_complex_types(),
-supports_out=False,
+supports_out=True,
sample_inputs_func=sample_inputs_linalg_lstsq,
check_batched_grad=False,
check_batched_gradgrad=False,