Added out= variant for torch.linalg.lstsq (#54721)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54721

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D27874711

Pulled By: mruberry

fbshipit-source-id: 696ebb6eb0bad81988e9cb7a081388a3a5ab3e2c
Ivan Yashchuk
2021-04-20 07:07:37 -07:00
committed by Facebook GitHub Bot
parent 43c747859c
commit 3d878dee45
6 changed files with 401 additions and 186 deletions


@ -918,8 +918,7 @@ static Tensor& linalg_solve_out_info(Tensor& result, Tensor& infos, const Tensor
// - 2-dimensional (2D) tensor or batch of 2D tensors (matrix case)
// original torch.solve supported only the matrix case, while NumPy works for both cases
// for the batched input we need to be able to distinguish them
- auto expected_batched_rhs_shape = IntArrayRef(input.sizes().data(), input.dim()-1); // input.shape[:-1]
- bool vector_case = other.dim() == 1 || (input.dim()-1 == other.dim() && other.sizes().equals(expected_batched_rhs_shape));
+ bool vector_case = linalg_solve_is_vector_rhs(input, other);
bool is_batched_column_major = false;
if (vector_case) {
@ -929,7 +928,7 @@ static Tensor& linalg_solve_out_info(Tensor& result, Tensor& infos, const Tensor
}
// if 'other' is a batch of 2D tensors, then 'input' can be non-batched and will be broadcasted
- auto expected_shape = expected_batched_rhs_shape;
+ auto expected_shape = IntArrayRef(input.sizes().data(), input.dim() - 1); // input.shape[:-1]
if (!vector_case && other.dim() > 2) {
expected_shape = other.sizes();
}
@ -1020,8 +1019,7 @@ Tensor& linalg_solve_out(const Tensor& input, const Tensor& other, Tensor& resul
// Now check LAPACK/MAGMA error codes
// batchCheckErrors(Tensor, char*) calls 'infos = infos.to(kCPU)'
- auto expected_batched_rhs_shape = IntArrayRef(input.sizes().data(), input.dim()-1); // input.shape[:-1]
- bool vector_case = other.dim() == 1 || (input.dim()-1 == other.dim() && other.sizes().equals(expected_batched_rhs_shape));
+ bool vector_case = linalg_solve_is_vector_rhs(input, other);
if (vector_case ? result.dim() > 1 : result.dim() > 2) {
batchCheckErrors(infos, "linalg_solve");
} else {
@ -2971,204 +2969,378 @@ Tensor& _lstsq_helper_cpu(
#endif
}
std::tuple<Tensor, Tensor, Tensor, Tensor> linalg_lstsq(
const Tensor& self, const Tensor& b,
c10::optional<double> cond,
c10::optional<std::string> driver) {
TORCH_CHECK(
self.device().type() == b.device().type(),
"torch.linalg.lstsq: input tensors should be on the same device"
);
TORCH_CHECK(
self.scalar_type() == b.scalar_type(),
"torch.linalg.lstsq: input tensors should be of the same dtype"
);
TORCH_CHECK(
self.dim() >= 2,
"torch.linalg.lstsq: input `self` Tensor should be at least 2D"
);
TORCH_CHECK(
b.dim() >= 1,
"torch.linalg.lstsq: input `b` Tensor should be at least 1D"
);
auto dim_diff = self.dim() - b.dim();
TORCH_CHECK(
0 <= dim_diff && dim_diff <= 1,
"torch.linalg.lstsq: self.dim() must be greater or equal to b.dim() and "
"(self.dim() - b.dim()) <= 1"
);
Tensor b_2d = dim_diff ? b.unsqueeze(-1) : b;
TORCH_CHECK(
self.size(-2) == b_2d.size(-2),
dim_diff ? "torch.linalg.lstsq: self.size(-2) should match b.size(-1)" :
"torch.linalg.lstsq: self.size(-2) should match b.size(-2)"
);
/*
Solves a least squares problem, that is, minimizes the squared Frobenius norm of (B - A X).
- // if `driver` is empty, we use `driver_opt` to be set to
- // c10::nullopt if working with CUDA tensors,
Input args:
* 'input' - Tensor containing batches of m-by-n matrix A.
* 'other' - Tensor containing batches of max(m, n)-by-nrhs matrix B.
* 'cond' - relative tolerance for determining rank of A.
* 'driver' - the name of the LAPACK driver that is used to compute the solution.
Output args (modified in-place):
* 'solution' - Tensor to store the solution matrix X.
* 'residuals' - Tensor to store values of the residual sum of squares for each column of the solution.
* 'rank' - Tensor to store the rank of A.
* 'singular_values' - Tensor to store the singular values of A.
* 'infos' - Tensor to store error codes of linear algebra math library.
For further details, please see the LAPACK documentation for GELS/GELSY/GELSS/GELSD routines.
*/
static void linalg_lstsq_out_info(
Tensor& solution,
Tensor& residuals,
Tensor& rank,
Tensor& singular_values,
Tensor& infos,
const Tensor& input,
const Tensor& other,
double rcond,
std::string& driver) {
// These internal asserts make explicit the assumptions in the implementation
// Error check with the actual error messages are done on the higher level of
// the hierarchy of calls
TORCH_INTERNAL_ASSERT(input.dim() >= 2);
TORCH_INTERNAL_ASSERT(other.dim() >= 1);
auto dim_diff = input.dim() - other.dim();
TORCH_INTERNAL_ASSERT(0 <= dim_diff && dim_diff <= 1);
TORCH_INTERNAL_ASSERT(input.scalar_type() == other.scalar_type());
TORCH_INTERNAL_ASSERT(input.device() == other.device());
TORCH_INTERNAL_ASSERT(solution.scalar_type() == input.scalar_type());
TORCH_INTERNAL_ASSERT(solution.device() == input.device());
TORCH_INTERNAL_ASSERT(residuals.device() == input.device());
TORCH_INTERNAL_ASSERT(rank.scalar_type() == at::kLong);
TORCH_INTERNAL_ASSERT(rank.device() == input.device());
auto real_dtype = toValueType(input.scalar_type());
TORCH_INTERNAL_ASSERT(singular_values.scalar_type() == real_dtype);
TORCH_INTERNAL_ASSERT(singular_values.device() == input.device());
TORCH_INTERNAL_ASSERT(infos.scalar_type() == at::kInt);
TORCH_INTERNAL_ASSERT(infos.device() == input.device());
TORCH_INTERNAL_ASSERT(infos.numel() == std::max<int64_t>(1, batchCount(input)));
TORCH_INTERNAL_ASSERT(infos.is_contiguous());
bool vector_case = linalg_solve_is_vector_rhs(input, other);
// we need to unsqueeze 'other' because 2-dimensional tensors are expected in the implementation
Tensor other_2d = vector_case ? other.unsqueeze(-1) : other;
TORCH_INTERNAL_ASSERT(input.size(-2) == other_2d.size(-2));
std::vector<int64_t> expected_solution_shape = broadcast_batch_size(input, other_2d, input.dim() - 2);
// the actual shape of the solution returned is (*, n,) or (*, n, nrhs)
// but LAPACK requires extra dimensions to store raw residuals
// so the expected shape is (*, max(m, n),) or (*, max(m, n), nrhs)
auto m = input.size(-2);
auto n = input.size(-1);
auto nrhs = other.size(-1);
expected_solution_shape.push_back(std::max(m, n));
if (!vector_case) {
expected_solution_shape.push_back(nrhs);
}
// if 'solution' has no elements we can modify it
if (solution.numel() == 0) {
if (vector_case) {
solution.resize_(expected_solution_shape, MemoryFormat::Contiguous);
} else {
auto shape_transposed = expected_solution_shape;
std::swap(shape_transposed.end()[-1], shape_transposed.end()[-2]);
solution.resize_(shape_transposed, MemoryFormat::Contiguous);
solution.transpose_(-2, -1);
}
}
// if 'solution' is non-empty it must have the expected shape
TORCH_INTERNAL_ASSERT(solution.sizes().equals(expected_solution_shape));
// 'solution' must be in batched column major order (Fortran contiguous) for 2D inputs
// or C contiguous for 1D input
if (vector_case) {
TORCH_INTERNAL_ASSERT(solution.is_contiguous());
} else {
TORCH_INTERNAL_ASSERT(solution.transpose(-2, -1).is_contiguous());
}
// for 1-dimensional 'other', we need to unsqueeze the 'solution' before passing to "apply_solve"
if (vector_case) {
solution = solution.unsqueeze_(-1);
}
// _linalg_lstsq_helper_ performs calculations in-place and 'solution' must be a copy of other_2d
solution.narrow(-2, 0, other_2d.size(-2)).copy_(other_2d);
// if 'rank' is empty we might resize it
auto input_batch_shape = IntArrayRef(input.sizes().cbegin(), input.sizes().cend() - 2);
if (rank.numel() == 0 && driver != "gels") { // gels driver doesn't set 'rank'
rank.resize_(input_batch_shape, MemoryFormat::Contiguous);
}
// if 'rank' is non-empty it must have the expected shape and be contiguous
if (driver != "gels") {
TORCH_INTERNAL_ASSERT(rank.sizes().equals(input_batch_shape));
TORCH_INTERNAL_ASSERT(rank.is_contiguous());
}
// if 'singular_values' is empty we might resize it
auto singular_values_shape = input_batch_shape.vec();
singular_values_shape.push_back(std::min(m, n));
if (singular_values.numel() == 0 && (driver == "gelsd" || driver == "gelss")) {
singular_values.resize_(singular_values_shape, MemoryFormat::Contiguous);
}
// if 'singular_values' is non-empty it must have the expected shape and be contiguous
if (driver == "gelsd" || driver == "gelss") {
TORCH_INTERNAL_ASSERT(singular_values.sizes().equals(singular_values_shape));
TORCH_INTERNAL_ASSERT(singular_values.is_contiguous());
}
// 'input' is modified in-place so we need a column-major copy
auto input_working_copy = copyBatchedColumnMajor(input);
// now the actual call that computes the result in-place (apply_lstsq)
at::_lstsq_helper_(solution, rank, singular_values, infos, input_working_copy, rcond, driver);
if (m > n && driver != "gelsy") {
// LAPACK stores the data needed to compute residuals in rows n:m (the last m - n rows)
auto raw_residuals = solution.narrow(/*dim=*/-2, /*start=*/n, /*length*/m - n);
if (raw_residuals.is_complex()) {
raw_residuals.mul_(raw_residuals.conj());
raw_residuals = at::real(raw_residuals);
} else {
raw_residuals.pow_(2);
}
at::sum_out(residuals, raw_residuals, /*dim=*/-2, /*keepdim=*/false, /*dtype*/real_dtype);
}
solution = solution.narrow(/*dim=*/-2, /*start=*/0, /*length*/n);
if (m == 0) {
solution.zero_();
}
// for 1-dimensional 'other', we need to squeeze the solution after "apply_lstsq"
if (vector_case) {
solution = solution.squeeze_(-1);
}
}
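For illustration only (not part of the diff): a minimal Python sanity check of the residual bookkeeping described above, assuming a CPU tensor, the 'gelsd' driver, and a full-rank overdetermined system, in which case the residuals derived from rows n:m equal the column-wise squared norms of B - A X.

import torch

m, n, nrhs = 5, 3, 2  # any m > n works for this illustration
A = torch.randn(m, n, dtype=torch.float64)
B = torch.randn(m, nrhs, dtype=torch.float64)

solution, residuals, rank, singular_values = torch.linalg.lstsq(A, B, driver="gelsd")
# For a full-rank overdetermined system the residuals equal ||B - A @ X||^2 per column.
manual = (B - A @ solution).pow(2).sum(dim=0)
print(torch.allclose(residuals, manual))  # expected: True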
static std::string get_default_lstsq_driver(c10::optional<std::string> driver, const Tensor& input) {
// if `driver` is empty, we set driver_str to "gels" if working with CUDA tensors,
// otherwise to "gelsy" driver.
// CUDA tensors are treated specially because MAGMA
// has only 'gels' driver supported.
- c10::optional<std::string> driver_opt = driver;
+ std::string driver_str;
// check whether the user provided name is a valid driver name
if (driver.has_value()) {
- auto driver_str = driver.value();
+ driver_str = driver.value();
// convert `driver_str` to lower case inplace.
std::transform(driver_str.begin(), driver_str.end(), driver_str.begin(),
[](unsigned char c) { return std::tolower(c); });
static std::unordered_set<std::string> allowed_drivers = {
"gels", "gelsy", "gelsd", "gelss"
};
- if (at::kCPU == self.device().type()) {
+ if (input.device() == at::kCPU) {
TORCH_CHECK(
allowed_drivers.find(driver_str) != allowed_drivers.end(),
"torch.linalg.lstsq: parameter `driver` should be one of "
"(gels, gelsy, gelsd, gelss)"
);
}
- //else if (at::kCUDA == self.device().type()) {
- else {
+ } else { // else if (input.is_cuda())
TORCH_CHECK(
driver_str == "gels",
"torch.linalg.lstsq: `driver` other than `gels` is not supported on CUDA"
);
}
} else {
// if driver name is not provided, set to default 'gelsy' if on CPU,
// or to `gels` if on CUDA.
driver_str = input.is_cuda() ? "gels" : "gelsy";
}
- // if driver name is not provided, set to default 'gelsy' if on CPU,
- // or to `gels` if on CUDA.
- else {
- driver_opt = (at::kCPU == self.device().type())
- ? c10::optional<std::string>("gelsy")
- : c10::optional<std::string>("gels");
return driver_str;
}
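For readers following along from Python, here is a rough mirror of the driver selection above (purely illustrative; the helper name default_lstsq_driver is hypothetical, only the C++ function is part of the diff):

def default_lstsq_driver(driver, is_cuda):
    # Mirrors get_default_lstsq_driver: validate a user-supplied name,
    # or fall back to 'gelsy' on CPU and 'gels' on CUDA.
    if driver is None:
        return "gels" if is_cuda else "gelsy"
    driver = driver.lower()
    if is_cuda and driver != "gels":
        raise RuntimeError("`driver` other than `gels` is not supported on CUDA")
    if not is_cuda and driver not in {"gels", "gelsy", "gelsd", "gelss"}:
        raise RuntimeError("parameter `driver` should be one of (gels, gelsy, gelsd, gelss)")
    return driver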
std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> linalg_lstsq_out(
const Tensor& input,
const Tensor& other,
c10::optional<double> rcond,
c10::optional<std::string> driver,
Tensor& solution,
Tensor& residuals,
Tensor& rank,
Tensor& singular_values) {
TORCH_CHECK(input.dim() >= 2, "torch.linalg.lstsq: input must have at least 2 dimensions.");
TORCH_CHECK(other.dim() >= 1, "torch.linalg.lstsq: other must have at least 1 dimension.");
TORCH_CHECK(
input.scalar_type() == other.scalar_type(),
"torch.linalg.lstsq: Expected input and other to have the same dtype, but got input's dtype ",
input.scalar_type(),
" and other's dtype ",
other.scalar_type());
auto dim_diff = input.dim() - other.dim();
TORCH_CHECK(
0 <= dim_diff && dim_diff <= 1,
"torch.linalg.lstsq: input.dim() must be greater or equal to other.dim() and (input.dim() - other.dim()) <= 1");
Tensor other_2d = dim_diff ? other.unsqueeze(-1) : other;
TORCH_CHECK(
input.size(-2) == other_2d.size(-2),
dim_diff ? "torch.linalg.lstsq: input.size(-2) should match other.size(-1)"
: "torch.linalg.lstsq: input.size(-2) should match other.size(-2)");
checkSameDevice("torch.linalg.lstsq", other, input, "other");
checkSameDevice("torch.linalg.lstsq", solution, input, "solution");
checkSameDevice("torch.linalg.lstsq", residuals, input, "residuals");
checkSameDevice("torch.linalg.lstsq", rank, input, "rank");
checkSameDevice("torch.linalg.lstsq", singular_values, input, "singular_values");
// 'solution' is expected to have same dtype as input
checkLinalgCompatibleDtype("torch.linalg.lstsq", solution, input, "solution");
// 'residuals' is expected to have real float dtype
ScalarType real_dtype = c10::toValueType(input.scalar_type());
checkLinalgCompatibleDtype("torch.linalg.lstsq", residuals.scalar_type(), real_dtype, "residuals");
// 'rank' is expected to have integer dtype
// actual LAPACK calls use int32_t type for rank, but we promote it to int64_t
// to be consistent with torch.linalg.matrix_rank output dtype
ScalarType rank_expected_type = ScalarType::Long;
checkLinalgCompatibleDtype("torch.linalg.lstsq", rank.scalar_type(), rank_expected_type, "rank");
// 'singular_values' is expected to have real float dtype
checkLinalgCompatibleDtype("torch.linalg.lstsq", singular_values.scalar_type(), real_dtype, "singular_values");
std::string driver_name = get_default_lstsq_driver(driver, input);
// set default rcond value
// TODO: Change this to match non-legacy NumPy behaviour
double rcond_value = rcond.has_value() && (rcond.value() > 0)
? rcond.value()
: _get_epsilon(c10::toValueType(input.scalar_type()));
auto infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, input.options().dtype(kInt));
// now check whether the provided output tensors can be used directly
// Two types of 'other' tensors are supported:
// - 1-dimensional (1D) tensor or batch of 1D tensors (vector case)
// - 2-dimensional (2D) tensor or batch of 2D tensors (matrix case)
// original torch.lstsq supported only the matrix case, while NumPy works for both cases
// for the batched input we need to be able to distinguish them
// auto expected_batched_rhs_shape = IntArrayRef(input.sizes().data(), input.dim() - 1); // input.shape[:-1]
// bool vector_case = other.dim() == 1 || (input.dim() - 1 == other.dim() && other.sizes().equals(expected_batched_rhs_shape));
bool vector_case = linalg_solve_is_vector_rhs(input, other);
// provided output tensor can be used directly if:
// 1. the shape matches the expected shape
// 2. the dtype matches the expected dtype
// 3. the tensor is contiguous
// Checks for the 'solution' tensor
std::vector<int64_t> expected_solution_shape = broadcast_batch_size(input, other_2d, input.dim() - 2);
// the actual shape of the solution returned is (*, n,) or (*, n, nrhs)
// but LAPACK requires extra dimensions so the expected shape is (*, max(m, n),) or (*, max(m, n), nrhs)
expected_solution_shape.push_back(std::max(input.size(-1), input.size(-2)));
if (!vector_case && other.dim() > 2) {
expected_solution_shape.push_back(other.size(-1));
}
- // CUDA has only `gels` driver now which ONLY works with overdetermined systems
- if (at::kCUDA == self.device().type()) {
- TORCH_CHECK(
- self.size(-2) >= self.size(-1),
- "torch.linalg.lstsq: only overdetermined systems (m >= n) are allowed on CUDA"
- );
bool solution_equal_expected_shape = solution.sizes().equals(expected_solution_shape);
bool solution_input_same_type = (solution.scalar_type() == input.scalar_type());
bool is_solution_batched_column_major = false;
if (vector_case) {
is_solution_batched_column_major = solution.is_contiguous();
} else if (!vector_case && solution.dim() >= 2) {
is_solution_batched_column_major = solution.transpose(-2, -1).is_contiguous();
}
- // LAPACK/MAGMA requries inputs to be in the column-major-order.
- auto self_working_copy = copyBatchedColumnMajor(self);
// 'residuals' is not checked here because at::sum_out(residuals, ...) does that
- // Tensor b must be of size (..., max(m, n), nrhs)
- // and in the column-major order.
- // We allow the batch dims of `self` to broadcast over the batch
- // dims of `b` so that it is possible to solve multiple systems with
- // with the same lhs (encoded by `self`) / rhs (encoded by `b`).
- // `b_working_copy` is modified in-place and the combination of
- // batch broadcasting plus LAPACK/MAGMA requirements impose the following
- // restrictions on sizes/strides of `b`:
- // 1. b.size = (broadcasted_batch_size(self, b), max(m, n), nrhs).
- // 2. b.stride should correspond to an almost contiguous Tensor in the column-major-order,
- // i.e. b.stride = b.transpose(-2, -1).contiguous().transpose(-2, -1).strides()
- auto m = self.size(-2);
- auto n = self.size(-1);
- auto b_working_copy = copyBatchedColumnMajor(b_2d,
- /*nrows=*/std::max(m, n),
- /*desired_batch_sizes=*/broadcast_batch_size(self, b_2d, self.dim() - 2));
auto input_batch_shape = IntArrayRef(input.sizes().cbegin(), input.sizes().cend() - 2);
- double rcond = cond.has_value() && (cond.value() > 0)
- ? cond.value()
- : _get_epsilon(c10::toValueType(self.scalar_type()));
- auto batch_shape = IntArrayRef(self.sizes().cbegin(), self.sizes().cend() - 2);
- Tensor rank = at::empty({0}, self.options().dtype(at::kLong));
- if (driver_opt.value() != "gels") {
- rank.resize_(batch_shape, MemoryFormat::Contiguous);
// Checks for the 'rank' tensor
// rank is a scalar value for each matrix in the batch so
// rank's expected shape is equal to input.shape[0:input.ndim-2]
bool rank_equal_expected_shape = true;
bool rank_equal_expected_type = true;
bool rank_is_contiguous = true;
if (driver_name != "gels") { // gels driver doesn't set 'rank'
rank_equal_expected_shape = rank.sizes().equals(input_batch_shape);
rank_equal_expected_type = (rank.scalar_type() == at::kLong);
rank_is_contiguous = rank.is_contiguous();
}
- auto singular_values_shape = batch_shape.vec();
- singular_values_shape.push_back(std::min(m, n));
- auto real_dtype = c10::toValueType(self.scalar_type());
- Tensor singular_values = at::empty({0}, self.options().dtype(real_dtype));
- if (driver_opt.value() == "gelsd" || driver_opt.value() == "gelss") {
- singular_values.resize_(singular_values_shape, MemoryFormat::Contiguous);
// Checks for the 'singular_values' tensor
// singular values are computed only with "gelsd" and "gelss" drivers currently
bool singular_values_equal_expected_shape = true;
bool singular_values_equal_expected_type = true;
bool singular_values_is_contiguous = true;
if (driver_name == "gelsd" || driver_name == "gelss") {
auto singular_values_shape = input_batch_shape.vec();
singular_values_shape.push_back(std::min(input.size(-1), input.size(-2)));
singular_values_equal_expected_shape = singular_values.sizes().equals(singular_values_shape);
singular_values_equal_expected_type = (singular_values.scalar_type() == real_dtype);
singular_values_is_contiguous = singular_values.is_contiguous();
}
- Tensor infos = at::zeros({std::max<int64_t>(1, batchCount(self))}, self.options().dtype(kInt).device(kCPU));
// if solution is not empty and not in batched column major format
bool copy_needed = (solution.numel() != 0 && !is_solution_batched_column_major);
copy_needed |= !solution_input_same_type; // or solution does not have the same dtype as input
copy_needed |= (solution.numel() != 0 && !solution_equal_expected_shape); // or solution does not have the expected shape
- Tensor x, residuals;
copy_needed |= !rank_equal_expected_type;
copy_needed |= (rank.numel() != 0 && !rank_equal_expected_shape);
copy_needed |= (rank.numel() != 0 && !rank_is_contiguous);
- // path if neither `self` nor `b` is empty
- if (self.numel() && b.numel()) {
- x = at::_lstsq_helper_(b_working_copy, rank, singular_values, infos, self_working_copy, rcond, driver_opt.value());
- if (m > n && driver_opt.value() != "gelsy") {
- residuals = x.narrow(-2, n, std::max(m, n) - n).abs().pow_(2).sum(-2);
- }
- x = x.narrow(-2, 0, n);
- }
- // if either `self` or `b` is empty, return an empty tensor or,
- // if non-zero sizes, return a tensor of zeros.
- else {
- x = b_working_copy.zero_().narrow(-2, 0, n);
copy_needed |= !singular_values_equal_expected_type;
copy_needed |= (singular_values.numel() != 0 && !singular_values_equal_expected_shape);
copy_needed |= (singular_values.numel() != 0 && !singular_values_is_contiguous);
if (copy_needed) { // we have to allocate temporary tensors
Tensor solution_tmp = at::empty({0}, input.options());
Tensor residuals_tmp = at::empty({0}, input.options().dtype(real_dtype));
Tensor rank_tmp = at::empty({0}, input.options().dtype(at::kLong));
Tensor singular_values_tmp = at::empty({0}, input.options().dtype(real_dtype));
linalg_lstsq_out_info(solution_tmp, residuals_tmp, rank_tmp, singular_values_tmp, infos, input, other, rcond_value, driver_name);
at::native::resize_output(solution, solution_tmp.sizes());
solution.copy_(solution_tmp);
at::native::resize_output(residuals, residuals_tmp.sizes());
residuals.copy_(residuals_tmp);
at::native::resize_output(rank, rank_tmp.sizes());
rank.copy_(rank_tmp);
at::native::resize_output(singular_values, singular_values_tmp.sizes());
singular_values.copy_(singular_values_tmp);
} else {
// else use the provided output storage directly
linalg_lstsq_out_info(solution, residuals, rank, singular_values, infos, input, other, rcond_value, driver_name);
}
auto return_empty_if_undefined = [&self](Tensor& t,
c10::optional<at::ScalarType> dtype = c10::nullopt,
c10::optional<std::vector<int64_t>> shape = c10::nullopt) {
if (t.defined()) {
return t;
}
else {
auto output_dtype = dtype.has_value() ? dtype.value() : self.scalar_type();
if (shape.has_value()) {
return at::empty(shape.value(), self.options().dtype(output_dtype));
}
else {
return at::empty({0}, self.options().dtype(output_dtype));
}
}
};
// Some output stays undefined for some values of driver.
// Instead of returning undefined tensors which get exposed as
// Nones in the Python interface, we return empty tensors.
// This way we follow the convention of output types in the
// torch.linalg namespace.
// NOTE: we run drivers only if both inputs are non-empty!
// Hence the code below explicitly handles each and every output
// if `self` is empty.
// Numpy and Scipy always return ranks for empty matrices,
// even for drivers which are not rank-revealing.
if (self.numel()) {
rank = return_empty_if_undefined(rank, at::kLong);
}
else {
rank = at::zeros(batch_shape, self.options().dtype(at::kLong));
}
// undefined residuals could only be an empty Tensor of shape (0)
residuals = return_empty_if_undefined(residuals);
if (!self.numel()
&& (driver_opt.value() == "gelss" || driver_opt.value() == "gelsd")) {
// when `self` is empty, return singular_values of shape
// (*self.shape[:-2], 0) only if driver is in ('gelss', 'gelsd')
auto singular_values_empty_shape = batch_shape.vec();
singular_values_empty_shape.push_back(0);
singular_values = return_empty_if_undefined(
singular_values,
at::toValueType(self.scalar_type()),
singular_values_empty_shape);
}
else {
// otherwise return an empty tensor of shape (0)
singular_values = return_empty_if_undefined(
singular_values,
at::toValueType(self.scalar_type()));
}
- if (self.dim() > 2) {
+ if (infos.numel() > 1) {
batchCheckErrors(infos, "torch.linalg.lstsq");
} else {
- singleCheckErrors(infos.item().toInt(), "torch.linalg.lstsq");
+ singleCheckErrors(infos.item<int64_t>(), "torch.linalg.lstsq");
}
- return std::make_tuple(x, residuals, rank, singular_values);
+ return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(solution, residuals, rank, singular_values);
}
std::tuple<Tensor, Tensor, Tensor, Tensor> linalg_lstsq(
const Tensor& input, const Tensor& other,
c10::optional<double> rcond,
c10::optional<std::string> driver) {
Tensor solution = at::empty({0}, input.options());
Tensor residuals = at::empty({0}, input.options().dtype(toValueType(input.scalar_type())));
Tensor rank = at::empty({0}, input.options().dtype(at::kLong));
Tensor singular_values = at::empty({0}, input.options().dtype(toValueType(input.scalar_type())));
std::tie(solution, residuals, rank, singular_values) =
at::linalg_lstsq_outf(input, other, rcond, driver, solution, residuals, rank, singular_values);
return std::make_tuple(solution, residuals, rank, singular_values);
}
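A minimal usage sketch of the new out= overload added by this PR (shapes, dtypes, and the 'gelsd' driver chosen for illustration; empty output tensors are resized as shown in linalg_lstsq_out above):

import torch

A = torch.randn(4, 3, dtype=torch.float64)
B = torch.randn(4, 2, dtype=torch.float64)

# Preallocate the four outputs; empty tensors are resized by the out= variant.
solution = torch.empty(0, dtype=torch.float64)
residuals = torch.empty(0, dtype=torch.float64)        # real dtype of A
rank = torch.empty(0, dtype=torch.int64)
singular_values = torch.empty(0, dtype=torch.float64)  # real dtype of A

torch.linalg.lstsq(A, B, driver="gelsd",
                   out=(solution, residuals, rank, singular_values))
print(solution.shape)  # expected: torch.Size([3, 2])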
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


@ -514,4 +514,20 @@ static inline void checkLinalgCompatibleDtype(const std::string& fn_name, Scalar
out_name, " with dtype ", out_type);
}
/*
Two types of 'other' tensors are supported when solving
a system of linear equations matmul(input, x) = other:
* 1-dimensional (1D) tensor or batch of 1D tensors (vector case)
* 2-dimensional (2D) tensor or batch of 2D tensors (matrix case).
The original torch.solve supported only the matrix case, while NumPy works for both cases.
For the batched input we need to be able to distinguish them.
Let input.shape = (batch_dimensions, m, n), then 'other' is of vector type if other.shape == (batch_dimensions, m).
This rule is compatible with NumPy, see https://github.com/numpy/numpy/blob/v1.20.0/numpy/linalg/linalg.py#L384-L389
*/
static inline bool linalg_solve_is_vector_rhs(const Tensor& input, const Tensor& other) {
auto expected_batched_rhs_shape = IntArrayRef(input.sizes().data(), input.dim() - 1); // input.shape[:-1]
bool vector_case = other.dim() == 1 || (input.dim() - 1 == other.dim() && other.sizes().equals(expected_batched_rhs_shape));
return vector_case;
}
}} // namespace at::native
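To illustrate the vector/matrix RHS convention the helper encodes (an editorial example, not part of the diff; the rule applies to both torch.linalg.solve and torch.linalg.lstsq, and the shapes below are the expected results under the NumPy-compatible rule described above):

import torch

A = torch.randn(2, 4, 3)
b_vec = torch.randn(2, 4)      # shape == A.shape[:-1] -> treated as a batch of vectors
b_mat = torch.randn(2, 4, 5)   # -> treated as a batch of matrices

print(torch.linalg.lstsq(A, b_vec).solution.shape)  # expected: torch.Size([2, 3])
print(torch.linalg.lstsq(A, b_mat).solution.shape)  # expected: torch.Size([2, 3, 5])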


@ -2667,6 +2667,11 @@ TORCH_CHECK(false, "torch.linalg.lstsq: MAGMA library not found in "
auto trans = MagmaNoTrans;
auto m = magma_int_cast(a.size(-2), "m");
auto n = magma_int_cast(a.size(-1), "n");
TORCH_CHECK(
m >= n,
"torch.linalg.lstsq: only overdetermined systems (input.size(-2) >= input.size(-1)) are allowed on CUDA");
auto nrhs = magma_int_cast(b.size(-1), "nrhs");
auto ldda = std::max<magma_int_t>(1, m);
auto lddb = std::max<magma_int_t>(1, std::max(m, n));
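Since MAGMA currently provides only the 'gels' path, underdetermined systems are rejected on CUDA. A quick hedged illustration of the failure mode the check above guards against (assuming a CUDA build; the exact message is the one in the TORCH_CHECK):

import torch

if torch.cuda.is_available():
    A = torch.randn(2, 3, device="cuda")  # m=2 < n=3: underdetermined
    B = torch.randn(2, 1, device="cuda")
    try:
        torch.linalg.lstsq(A, B)
    except RuntimeError as e:
        print(e)  # ... only overdetermined systems ... are allowed on CUDA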


@ -8637,6 +8637,12 @@
dispatch:
CompositeExplicitAutograd: linalg_lstsq
- func: linalg_lstsq.out(Tensor self, Tensor b, float? cond=None, *, str? driver=None, Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values) -> (Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values)
python_module: linalg
variants: function
dispatch:
CPU, CUDA: linalg_lstsq_out
- func: _lstsq_helper_(Tensor(a!) self, Tensor(b!) rank, Tensor(c!) singular_values, Tensor(d!) infos, Tensor a, float cond, str driver_name) -> Tensor(a!)
variants: function
dispatch:


@ -132,7 +132,7 @@ class TestLinalg(TestCase):
sol2 = a.pinverse() @ b
self.assertEqual(sol, sol2, atol=1e-5, rtol=1e-5)
- def check_correctness_ref(a, b, res, ref):
+ def check_correctness_ref(a, b, res, ref, driver="default"):
def apply_if_not_empty(t, f):
if t.numel():
return f(t)
@ -157,18 +157,34 @@ class TestLinalg(TestCase):
rank_1d = apply_if_not_empty(res.rank, lambda t: t.view(-1))
singular_values_2d = res.singular_values.view(batch_size, res.singular_values.shape[-1])
- for i in range(batch_size):
- sol, residuals, rank, singular_values = ref(
- a_3d.select(0, i).numpy(),
- b_3d.select(0, i).numpy()
- )
- # Singular values are None when lapack_driver='gelsy' in SciPy
- if singular_values is None:
- singular_values = []
- self.assertEqual(sol, solution_3d.select(0, i), atol=1e-5, rtol=1e-5)
- self.assertEqual(residuals, select_if_not_empty(residuals_2d, i), atol=1e-5, rtol=1e-5)
- self.assertEqual(rank, select_if_not_empty(rank_1d, i), atol=1e-5, rtol=1e-5)
- self.assertEqual(singular_values, singular_values_2d.select(0, i), atol=1e-5, rtol=1e-5)
if a.numel() > 0:
for i in range(batch_size):
sol, residuals, rank, singular_values = ref(
a_3d.select(0, i).numpy(),
b_3d.select(0, i).numpy()
)
# Singular values are None when lapack_driver='gelsy' in SciPy
if singular_values is None:
singular_values = []
self.assertEqual(sol, solution_3d.select(0, i), atol=1e-5, rtol=1e-5)
self.assertEqual(residuals, select_if_not_empty(residuals_2d, i), atol=1e-5, rtol=1e-5)
self.assertEqual(rank, select_if_not_empty(rank_1d, i), atol=1e-5, rtol=1e-5)
self.assertEqual(singular_values, singular_values_2d.select(0, i), atol=1e-5, rtol=1e-5)
else:
self.assertEqual(res.solution.shape, (*a.shape[:-2], n, nrhs))
self.assertEqual(res.rank.shape, a.shape[:-2])
# residuals are not always computed (and have non-zero shape)
if m > n and driver != "gelsy":
self.assertEqual(res.residuals.shape, (*a.shape[:-2], 0))
else:
self.assertEqual(res.residuals.shape, (0, ))
# singular_values are not always computed (and have non-zero shape)
if driver == "default" or driver == "gelsd" or driver == "gelss":
self.assertEqual(res.singular_values.shape, (*a.shape[:-2], min(m, n)))
else:
self.assertEqual(res.singular_values.shape, (0, ))
def check_correctness_scipy(a, b, res, driver, cond):
if TEST_SCIPY and driver not in (None, 'gels'):
@ -176,7 +192,7 @@ class TestLinalg(TestCase):
def scipy_ref(a, b):
return scipy.linalg.lstsq(a, b, lapack_driver=driver, cond=cond)
- check_correctness_ref(a, b, res, scipy_ref)
+ check_correctness_ref(a, b, res, scipy_ref, driver=driver)
def check_correctness_numpy(a, b, res, driver, cond):
if driver in ('gelsd', 'gelss'):
@ -317,16 +333,16 @@ class TestLinalg(TestCase):
a = torch.rand(2, 3, dtype=dtype, device=device)
b = torch.rand(3, dtype=dtype, device=device)
- with self.assertRaisesRegex(RuntimeError, 'input `self` Tensor should be at least 2D'):
+ with self.assertRaisesRegex(RuntimeError, 'input must have at least 2 dimensions'):
torch.linalg.lstsq(b, b)
- with self.assertRaisesRegex(RuntimeError, 'input `b` Tensor should be at least 1D'):
+ with self.assertRaisesRegex(RuntimeError, 'other must have at least 1 dimension'):
torch.linalg.lstsq(a, torch.tensor(1, dtype=dtype, device=device))
- with self.assertRaisesRegex(RuntimeError, r'self.size\(-2\) should match b.size\(-1\)'):
+ with self.assertRaisesRegex(RuntimeError, r'input.size\(-2\) should match other.size\(-1\)'):
torch.linalg.lstsq(a, b)
- with self.assertRaisesRegex(RuntimeError, r'self.size\(-2\) should match b.size\(-2\)'):
+ with self.assertRaisesRegex(RuntimeError, r'input.size\(-2\) should match other.size\(-2\)'):
torch.linalg.lstsq(a, b.unsqueeze(-1))
def complement_device(device):
@ -338,11 +354,11 @@ class TestLinalg(TestCase):
a = torch.rand(2, 2, 2, 2, dtype=dtype, device=device)
b = torch.rand(2, 2, 2, dtype=dtype, device=complement_device(device))
if a.device != b.device:
- with self.assertRaisesRegex(RuntimeError, 'input tensors should be on the same device'):
+ with self.assertRaisesRegex(RuntimeError, 'be on the same device'):
torch.linalg.lstsq(a, b)
b = (torch.rand(2, 2, 2, dtype=dtype, device=device) * 100).long()
- with self.assertRaisesRegex(RuntimeError, 'input tensors should be of the same dtype'):
+ with self.assertRaisesRegex(RuntimeError, 'the same dtype'):
torch.linalg.lstsq(a, b)
a = torch.rand(2, 2, 2, 2, dtype=dtype, device=device)
@ -359,7 +375,7 @@ class TestLinalg(TestCase):
if device != 'cpu':
a = torch.rand(2, 3, dtype=dtype, device=device)
b = torch.rand(2, 1, dtype=dtype, device=device)
- with self.assertRaisesRegex(RuntimeError, r'only overdetermined systems \(m >= n\) are allowed on CUDA'):
+ with self.assertRaisesRegex(RuntimeError, r'only overdetermined systems'):
torch.linalg.lstsq(a, b)
@skipCUDAIfNoMagma


@ -3712,7 +3712,7 @@ op_db: List[OpInfo] = [
aten_name='linalg_lstsq',
op=torch.linalg.lstsq,
dtypes=floating_and_complex_types(),
- supports_out=False,
+ supports_out=True,
sample_inputs_func=sample_inputs_linalg_lstsq,
check_batched_grad=False,
check_batched_gradgrad=False,