Rename btrifact* to lu (#18435)

Summary:
Changelog:

- Renames `btrifact` and `btrifact_with_info` to `lu`to remain consistent with other factorization methods (`qr` and `svd`).
- Now, we will only have one function and methods named `lu`, which performs `lu` decomposition. This function takes a get_infos kwarg, which when set to True includes a infos tensor in the tuple.
- Rename all tests, fix callsites
- Create a tentative alias for `lu` under the name `btrifact` and `btrifact_with_info`, and add a deprecation warning to not promote usage.
- Add the single batch version for `lu` so that users don't have to unsqueeze and squeeze for a single square matrix (see changes in determinant computation in `LinearAlgebra.cpp`)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18435

Differential Revision: D14680352

Pulled By: soumith

fbshipit-source-id: af58dfc11fa53d9e8e0318c720beaf5502978cd8
This commit is contained in:
Vishwak Srinivasan
2019-03-29 00:27:48 -07:00
committed by Facebook Github Bot
parent c21e763cd6
commit d859031ebf
20 changed files with 351 additions and 281 deletions

View File

@ -700,8 +700,6 @@ class CAFFE2_API Tensor {
std::tuple<Tensor,Tensor> geqrf() const;
Tensor orgqr(const Tensor & input2) const;
Tensor ormqr(const Tensor & input2, const Tensor & input3, bool left=true, bool transpose=false) const;
std::tuple<Tensor,Tensor> btrifact(bool pivot=true) const;
std::tuple<Tensor,Tensor,Tensor> btrifact_with_info(bool pivot=true) const;
Tensor btrisolve(const Tensor & LU_data, const Tensor & LU_pivots) const;
Tensor multinomial(int64_t num_samples, bool replacement=false, Generator * generator=nullptr) const;
Tensor lgamma() const;

View File

@ -1171,12 +1171,6 @@ inline Tensor Tensor::orgqr(const Tensor & input2) const {
inline Tensor Tensor::ormqr(const Tensor & input2, const Tensor & input3, bool left, bool transpose) const {
return type().ormqr(*this, input2, input3, left, transpose);
}
inline std::tuple<Tensor,Tensor> Tensor::btrifact(bool pivot) const {
return type().btrifact(*this, pivot);
}
inline std::tuple<Tensor,Tensor,Tensor> Tensor::btrifact_with_info(bool pivot) const {
return type().btrifact_with_info(*this, pivot);
}
inline Tensor Tensor::btrisolve(const Tensor & LU_data, const Tensor & LU_pivots) const {
return type().btrisolve(*this, LU_data, LU_pivots);
}

View File

@ -578,8 +578,6 @@ struct CAFFE2_API Type {
virtual std::tuple<Tensor,Tensor> geqrf(const Tensor & self) const = 0;
virtual Tensor orgqr(const Tensor & self, const Tensor & input2) const = 0;
virtual Tensor ormqr(const Tensor & self, const Tensor & input2, const Tensor & input3, bool left, bool transpose) const = 0;
virtual std::tuple<Tensor,Tensor> btrifact(const Tensor & self, bool pivot) const = 0;
virtual std::tuple<Tensor,Tensor,Tensor> btrifact_with_info(const Tensor & self, bool pivot) const = 0;
virtual Tensor btrisolve(const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots) const = 0;
virtual Tensor multinomial(const Tensor & self, int64_t num_samples, bool replacement, Generator * generator) const = 0;
virtual Tensor lgamma(const Tensor & self) const = 0;

View File

@ -88,6 +88,7 @@ _(aten, _log10) \
_(aten, _log1p) \
_(aten, _log2) \
_(aten, _logspace) \
_(aten, _lu_with_info) \
_(aten, _masked_scale) \
_(aten, _mm) \
_(aten, _mv) \
@ -224,8 +225,6 @@ _(aten, bincount) \
_(aten, blackman_window) \
_(aten, bmm) \
_(aten, broadcast_tensors) \
_(aten, btrifact) \
_(aten, btrifact_with_info) \
_(aten, btrisolve) \
_(aten, cartesian_prod) \
_(aten, cat) \

View File

@ -51,8 +51,8 @@ void lapackSolve(int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b,
}
template<class scalar_t>
void lapackGetrf(int m, int n, scalar_t *a, int lda, int *ipiv, int *info) {
AT_ERROR("getrf only takes float or double Tensors");
void lapackLu(int m, int n, scalar_t *a, int lda, int *ipiv, int *info) {
AT_ERROR("lu only takes float or double Tensors");
}
template<class scalar_t>
@ -92,11 +92,11 @@ template<> void lapackGetri<float>(int n, float *a, int lda, int *ipiv, float *w
sgetri_(&n, a, &lda, ipiv, work, &lwork, info);
}
template<> void lapackGetrf<double>(int m, int n, double *a, int lda, int *ipiv, int *info) {
template<> void lapackLu<double>(int m, int n, double *a, int lda, int *ipiv, int *info) {
dgetrf_(&m, &n, a, &lda, ipiv, info);
}
template<> void lapackGetrf<float>(int m, int n, float *a, int lda, int *ipiv, int *info) {
template<> void lapackLu<float>(int m, int n, float *a, int lda, int *ipiv, int *info) {
sgetrf_(&m, &n, a, &lda, ipiv, info);
}
@ -219,7 +219,7 @@ static void apply_inverse(Tensor& self, std::vector<int64_t>& infos) {
for (int64_t i = 0; i < batch_size; i++) {
int info;
scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
lapackGetrf<scalar_t>(n, n, self_working_ptr, n, ipiv.data<int>(), &info);
lapackLu<scalar_t>(n, n, self_working_ptr, n, ipiv.data<int>(), &info);
infos[i] = info;
if (info != 0) {
return;
@ -406,41 +406,44 @@ Tensor& cholesky_out(Tensor &result, const Tensor &self, bool upper) {
return result;
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ btrifact ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template<typename scalar_t>
static void apply_btrifact(Tensor& self, Tensor& pivots, Tensor& infos) {
static void apply_lu(Tensor& self, Tensor& pivots, Tensor& infos) {
#ifndef USE_LAPACK
AT_ERROR("btrifact: LAPACK library not found in compilation");
AT_ERROR("lu: LAPACK library not found in compilation");
#else
auto self_data = self.data<scalar_t>();
auto self_matrix_stride = matrixStride(self);
auto batch_size = batchCount(self);
auto pivots_data = pivots.data<int>();
auto pivots_matrix_stride = pivots.size(-1);
auto infos_data = infos.data<int>();
auto n = self.size(-1);
for (int64_t i = 0; i < batch_size; i++) {
scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
int* pivots_working_ptr = &pivots_data[i * pivots_matrix_stride];
int* infos_working_ptr = &infos_data[i];
lapackGetrf<scalar_t>(n, n, self_working_ptr, n, pivots_working_ptr, infos_working_ptr);
if (self.dim() == 2) {
lapackLu<scalar_t>(n, n, self_data, n, pivots_data, infos_data);
} else {
auto self_matrix_stride = matrixStride(self);
auto batch_size = batchCount(self);
auto pivots_matrix_stride = pivots.size(-1);
for (int64_t i = 0; i < batch_size; i++) {
scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
int* pivots_working_ptr = &pivots_data[i * pivots_matrix_stride];
int* infos_working_ptr = &infos_data[i];
lapackLu<scalar_t>(n, n, self_working_ptr, n, pivots_working_ptr, infos_working_ptr);
}
}
#endif
}
std::tuple<Tensor, Tensor, Tensor> _btrifact_helper_cpu(const Tensor& self, bool pivot) {
AT_CHECK(pivot, "btrifact without pivoting is not implemented on the CPU");
AT_CHECK(self.dim() > 2,
"expected tensor with more than 2 dimensions, got size: ", self.sizes(),
std::tuple<Tensor, Tensor, Tensor> _lu_with_info_cpu(const Tensor& self, bool pivot, bool check_errors) {
AT_CHECK(pivot, "lu without pivoting is not implemented on the CPU");
AT_CHECK(self.dim() >= 2,
"expected tensor with 2 or more dimensions, got size: ", self.sizes(),
" instead");
squareCheckInputs(self);
auto req_size = self.sizes().vec();
req_size.pop_back();
auto pivots_tensor = at::zeros(req_size, self.options().dtype(kInt));
auto pivots_tensor = at::empty(req_size, self.options().dtype(kInt));
req_size.pop_back();
auto infos_tensor = at::zeros(req_size, self.options().dtype(kInt));
@ -449,55 +452,20 @@ std::tuple<Tensor, Tensor, Tensor> _btrifact_helper_cpu(const Tensor& self, bool
self_working_copy = at::empty_like(self);
} else {
self_working_copy = cloneBatchedColumnMajor(self);
AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "btrifact_cpu", [&]{
apply_btrifact<scalar_t>(self_working_copy, pivots_tensor, infos_tensor);
AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lu_cpu", [&]{
apply_lu<scalar_t>(self_working_copy, pivots_tensor, infos_tensor);
});
}
if (check_errors) {
if (self.dim() == 2) {
singleCheckErrors(infos_tensor.item<int64_t>(), "lu");
} else {
batchCheckErrors(infos_tensor, "lu");
}
}
return std::make_tuple(self_working_copy, pivots_tensor, infos_tensor);
}
std::tuple<Tensor, Tensor> btrifact(const Tensor& self, bool pivot) {
Tensor LU_fact, pivots, infos;
std::tie(LU_fact, pivots, infos) = at::_btrifact_helper(self, pivot);
batchCheckErrors(infos, "btrifact");
return std::make_tuple(LU_fact, pivots);
}
std::tuple<Tensor&, Tensor&> btrifact_out(
Tensor& A_LU,
Tensor& pivots,
const Tensor& self,
bool pivot) {
Tensor infos, A_LU_tmp, pivots_tmp;
std::tie(A_LU_tmp, pivots_tmp, infos) = at::_btrifact_helper(self, pivot);
batchCheckErrors(infos, "btrifact");
A_LU.resize_as_(A_LU_tmp).copy_(A_LU_tmp);
pivots.resize_as_(pivots_tmp).copy_(pivots_tmp);
return std::tuple<Tensor&, Tensor&>(A_LU, pivots);
}
std::tuple<Tensor, Tensor, Tensor> btrifact_with_info(
const Tensor& self,
bool pivot) {
Tensor LU_fact, pivots, infos;
std::tie(LU_fact, pivots, infos) = at::_btrifact_helper(self, pivot);
return std::make_tuple(LU_fact, pivots, infos);
}
std::tuple<Tensor&, Tensor&, Tensor&> btrifact_with_info_out(
Tensor& A_LU,
Tensor& pivots,
Tensor& info,
const Tensor& self,
bool pivot) {
Tensor info_tmp, A_LU_tmp, pivots_tmp;
std::tie(A_LU_tmp, pivots_tmp, info_tmp) = at::_btrifact_helper(self, pivot);
A_LU.resize_as_(A_LU_tmp).copy_(A_LU_tmp);
pivots.resize_as_(pivots_tmp).copy_(pivots_tmp);
info.resize_as_(info_tmp).copy_(info_tmp);
return std::tuple<Tensor&, Tensor&, Tensor&>(A_LU, pivots, info);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triu/tril ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <typename scalar_t, bool upper>

View File

@ -21,10 +21,8 @@ namespace native {
// where info helps us identify singular matrices.
static inline std::tuple<double, Tensor, int> _lu_det_P_diag_U_info(const Tensor& self) {
Tensor p, lu, info;
std::tie(lu, p, info) = self.unsqueeze(0).btrifact_with_info();
p.squeeze_(0);
lu.squeeze_(0);
int int_info = info.squeeze_().item<int32_t>();
std::tie(lu, p, info) = at::_lu_with_info(self, /*pivot=*/true, /*check_errors=*/false);
int int_info = info.item<int32_t>();
AT_CHECK(int_info >= 0, "LU factorization (getrf) failed with info = ", int_info);
auto n = self.size(0);
auto num_exchanges = (at::arange(1, n + 1, p.options()) != p).nonzero().size(0);

View File

@ -109,7 +109,7 @@ static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A) {
" but each b matrix is ", self.size(-2), " by ", self.size(-1));
}
// Validates input shapes for operations on batches of square matrices (inverse, cholesky)
// Validates input shapes for operations on batches of square matrices (inverse, cholesky, lu)
static inline void squareCheckInputs(const Tensor& self) {
AT_CHECK(self.size(-1) == self.size(-2),
"A must be batches of square matrices, "

View File

@ -35,18 +35,32 @@ void magmaSolveBatched(
}
template<class scalar_t>
void magmaGetrfBatched(
magma_int_t m, magma_int_t n, scalar_t** dA_array, magma_int_t ldda,
magma_int_t** ipiv_array, magma_int_t* info_array, magma_int_t batchsize,
const MAGMAQueue& magma_queue) {
AT_ERROR("getrf only takes float or double Tensors");
void magmaLu(
magma_int_t m, magma_int_t n, scalar_t* dA, magma_int_t ldda,
magma_int_t* ipiv, magma_int_t* info) {
AT_ERROR("lu only takes float or double Tensors");
}
template<class scalar_t>
void magmaGetrfNoPivBatched(
void magmaLuBatched(
magma_int_t m, magma_int_t n, scalar_t** dA_array, magma_int_t ldda,
magma_int_t** ipiv_array, magma_int_t* info_array, magma_int_t batchsize,
const MAGMAQueue& magma_queue) {
AT_ERROR("lu only takes float or double Tensors");
}
template<class scalar_t>
void magmaLuNoPiv(
magma_int_t m, magma_int_t n, scalar_t* dA, magma_int_t ldda,
magma_int_t* info) {
AT_ERROR("lu only takes float or double Tensors");
}
template<class scalar_t>
void magmaLuNoPivBatched(
magma_int_t m, magma_int_t n, scalar_t** dA_array, magma_int_t ldda,
magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
AT_ERROR("getrf only takes float or double Tensors");
AT_ERROR("lu only takes float or double Tensors");
}
template<class scalar_t>
@ -131,7 +145,21 @@ void magmaSolveBatched<float>(
}
template<>
void magmaGetrfBatched<double>(
void magmaLu<double>(
magma_int_t m, magma_int_t n, double* dA, magma_int_t ldda,
magma_int_t* ipiv, magma_int_t* info) {
magma_dgetrf_gpu(m, n, dA, ldda, ipiv, info);
}
template<>
void magmaLu<float>(
magma_int_t m, magma_int_t n, float* dA, magma_int_t ldda,
magma_int_t* ipiv, magma_int_t* info) {
magma_sgetrf_gpu(m, n, dA, ldda, ipiv, info);
}
template<>
void magmaLuBatched<double>(
magma_int_t m, magma_int_t n, double** dA_array, magma_int_t ldda,
magma_int_t** ipiv_array, magma_int_t* info_array, magma_int_t batchsize,
const MAGMAQueue& magma_queue) {
@ -139,7 +167,7 @@ void magmaGetrfBatched<double>(
}
template<>
void magmaGetrfBatched<float>(
void magmaLuBatched<float>(
magma_int_t m, magma_int_t n, float** dA_array, magma_int_t ldda,
magma_int_t** ipiv_array, magma_int_t* info_array, magma_int_t batchsize,
const MAGMAQueue& magma_queue) {
@ -147,14 +175,28 @@ void magmaGetrfBatched<float>(
}
template<>
void magmaGetrfNoPivBatched<double>(
void magmaLuNoPiv<double>(
magma_int_t m, magma_int_t n, double* dA, magma_int_t ldda,
magma_int_t* info) {
magma_dgetrf_nopiv_gpu(m, n, dA, ldda, info);
}
template<>
void magmaLuNoPiv<float>(
magma_int_t m, magma_int_t n, float* dA, magma_int_t ldda,
magma_int_t* info) {
magma_sgetrf_nopiv_gpu(m, n, dA, ldda, info);
}
template<>
void magmaLuNoPivBatched<double>(
magma_int_t m, magma_int_t n, double** dA_array, magma_int_t ldda,
magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
magma_dgetrf_nopiv_batched(m, n, dA_array, ldda, info_array, batchsize, magma_queue.get_queue());
}
template<>
void magmaGetrfNoPivBatched<float>(
void magmaLuNoPivBatched<float>(
magma_int_t m, magma_int_t n, float** dA_array, magma_int_t ldda,
magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
magma_sgetrf_nopiv_batched(m, n, dA_array, ldda, info_array, batchsize, magma_queue.get_queue());
@ -373,7 +415,7 @@ AT_ERROR("inverse: MAGMA library not found in "
}
MAGMAQueue magma_queue(self.get_device());
magmaGetrfBatched<scalar_t>(
magmaLuBatched<scalar_t>(
n, n, self_array, n, ipiv_array, info_array,
batch_size, magma_queue);
@ -527,75 +569,96 @@ Tensor _cholesky_helper_cuda(const Tensor& self, bool upper) {
}
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ btrifact ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <typename scalar_t>
static void apply_btrifact(Tensor& self, Tensor& pivots, Tensor& infos) {
static void apply_lu(Tensor& self, Tensor& pivots, Tensor& infos, bool get_pivots) {
#ifndef USE_MAGMA
AT_ERROR("btrifact: MAGMA library not found in "
AT_ERROR("lu: MAGMA library not found in "
"compilation. Please rebuild with MAGMA.");
#else
auto self_data = self.data<scalar_t>();
auto self_matrix_stride = matrixStride(self);
magma_int_t batch_size = magma_int_cast(batchCount(self), "batchCount");
magma_int_t n = magma_int_cast(self.size(-1), "n");
scalar_t** self_array;
ALLOCATE_ARRAY(self_array, scalar_t*, batch_size, self);
if (self.dim() == 2) {
// If `pivots` is defined, then we have to compute them.
// We will use the normal getrf function to compute the LU factorization
// and the pivots
// We create temporary tensors on the CPU, because tensors on the GPU
// cause segfault when passed to magmaLu and magmaLuNoPiv. The data is later
// copied to the appropriate tensors.
Tensor info_tmp = at::zeros({}, at::kInt);
if (get_pivots) {
Tensor piv_tmp = at::empty({n}, at::kInt);
magmaLu<scalar_t>(
n, n, self_data, n, piv_tmp.data<magma_int_t>(), info_tmp.data<magma_int_t>());
pivots.copy_(piv_tmp);
} else {
magmaLuNoPiv<scalar_t>(n, n, self_data, n, info_tmp.data<magma_int_t>());
}
infos.copy_(info_tmp);
} else {
auto self_matrix_stride = matrixStride(self);
magma_int_t batch_size = magma_int_cast(batchCount(self), "batchCount");
// Set up the created arrays
for (int64_t i = 0; i < batch_size; i++) {
self_array[i] = &self_data[i * self_matrix_stride];
}
scalar_t** self_array;
ALLOCATE_ARRAY(self_array, scalar_t*, batch_size, self);
MAGMAQueue magma_queue(self.get_device());
// If `pivots` is defined, then we have to compute them.
// We will use the normal getrf function to compute the LU factorization
// and the pivots
if (pivots.defined()) {
auto pivots_data = pivots.data<magma_int_t>();
auto pivots_matrix_stride = pivots.size(-1);
magma_int_t** pivots_array;
ALLOCATE_ARRAY(pivots_array, magma_int_t*, batch_size, pivots);
// Set up the created arrays
for (int64_t i = 0; i < batch_size; i++) {
pivots_array[i] = &pivots_data[i * pivots_matrix_stride];
self_array[i] = &self_data[i * self_matrix_stride];
}
magmaGetrfBatched<scalar_t>(
n, n, self_array, n, pivots_array,
infos.data<magma_int_t>(), batch_size, magma_queue);
} else {
magmaGetrfNoPivBatched<scalar_t>(
n, n, self_array, n, infos.data<magma_int_t>(),
batch_size, magma_queue);
MAGMAQueue magma_queue(self.get_device());
// Same comment as in the case of single matrix above.
if (get_pivots) {
auto pivots_data = pivots.data<magma_int_t>();
auto pivots_matrix_stride = pivots.size(-1);
magma_int_t** pivots_array;
ALLOCATE_ARRAY(pivots_array, magma_int_t*, batch_size, pivots);
for (int64_t i = 0; i < batch_size; i++) {
pivots_array[i] = &pivots_data[i * pivots_matrix_stride];
}
magmaLuBatched<scalar_t>(
n, n, self_array, n, pivots_array,
infos.data<magma_int_t>(), batch_size, magma_queue);
} else {
magmaLuNoPivBatched<scalar_t>(
n, n, self_array, n, infos.data<magma_int_t>(),
batch_size, magma_queue);
}
}
#endif
}
std::tuple<Tensor, Tensor, Tensor> _btrifact_helper_cuda(const Tensor& self, bool pivot) {
AT_CHECK(self.dim() > 2,
"expected tensor with more than 2 dimensions, got size: ", self.sizes(),
std::tuple<Tensor, Tensor, Tensor> _lu_with_info_cuda(const Tensor& self, bool pivot, bool check_errors) {
AT_CHECK(self.dim() >= 2,
"expected tensor with 2 or more dimensions, got size: ", self.sizes(),
" instead");
squareCheckInputs(self);
auto req_size = self.sizes().vec();
req_size.pop_back();
Tensor pivots_tensor;
if (pivot) {
pivots_tensor = at::zeros(req_size, self.options().dtype(kInt));
}
Tensor pivots_tensor = at::zeros(req_size, self.options().dtype(at::kInt));
req_size.pop_back();
auto infos_tensor = at::zeros(req_size, self.options().dtype(kInt));
auto infos_tensor = at::zeros(req_size, self.options().dtype(at::kInt));
Tensor self_working_copy;
if (self.numel() == 0) {
self_working_copy = at::empty_like(self);
} else {
self_working_copy = cloneBatchedColumnMajor(self);
AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "btrifact_cuda", [&]{
apply_btrifact<scalar_t>(self_working_copy, pivots_tensor, infos_tensor);
AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lu_cuda", [&]{
apply_lu<scalar_t>(self_working_copy, pivots_tensor, infos_tensor, pivot);
});
}
if (check_errors) {
if (self.dim() == 2) {
singleCheckErrors(infos_tensor.item<int64_t>(), "lu");
} else {
batchCheckErrors(infos_tensor, "lu");
}
}
return std::make_tuple(self_working_copy, pivots_tensor, infos_tensor);
}

View File

@ -3818,26 +3818,12 @@
matches_jit_signature: True
variants: method, function
- func: btrifact(Tensor self, *, bool pivot=True, Tensor(a!) A_LU, Tensor(b!) pivots) -> (Tensor(a!), Tensor(b!))
matches_jit_signature: True
- func: btrifact(Tensor self, *, bool pivot=True) -> (Tensor, Tensor)
matches_jit_signature: True
variants: method, function
- func: btrifact_with_info(Tensor self, *, bool pivot=True, Tensor(a!) A_LU, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!), Tensor(b!), Tensor(c!))
matches_jit_signature: True
- func: btrifact_with_info(Tensor self, *, bool pivot=True) -> (Tensor, Tensor, Tensor)
matches_jit_signature: True
variants: method, function
- func: _btrifact_helper(Tensor self, bool pivot) -> (Tensor, Tensor, Tensor)
- func: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor, Tensor, Tensor)
matches_jit_signature: True
variants: function
dispatch:
CPU: _btrifact_helper_cpu
CUDA: _btrifact_helper_cuda
CPU: _lu_with_info_cpu
CUDA: _lu_with_info_cuda
- func: btrisolve(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!)
matches_jit_signature: True

View File

@ -306,6 +306,7 @@ view of a storage and defines numeric operations on it.
.. automethod:: long
.. automethod:: lt
.. automethod:: lt_
.. automethod:: lu
.. automethod:: map_
.. automethod:: masked_scatter_
.. automethod:: masked_scatter

View File

@ -315,6 +315,7 @@ BLAS and LAPACK Operations
.. autofunction:: det
.. autofunction:: logdet
.. autofunction:: slogdet
.. autofunction:: lu
.. autofunction:: matmul
.. autofunction:: matrix_power
.. autofunction:: matrix_rank

View File

@ -2361,8 +2361,8 @@ class TestCuda(TestCase):
@skipIfRocm
@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
def test_btrifact(self):
_TestTorchMixin._test_btrifact(self, lambda t: t.cuda())
def test_lu(self):
_TestTorchMixin._test_lu(self, lambda t: t.cuda())
@skipIfRocm
@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")

View File

@ -1746,44 +1746,43 @@ class _TestTorchMixin(object):
_test_mm(n, m, p, torch.int64, lambda x, y: torch.randint(0, 100, (x, y), dtype=torch.int64))
@staticmethod
def _test_btrifact(self, cast):
def _test_lu(self, cast):
from common_utils import random_fullrank_matrix_distinct_singular_value as fullrank
def run_test(matrix_size, batches, cast):
a = cast(fullrank(matrix_size, *batches))
a_LU_info, pivots_info, info_ = a.btrifact_with_info()
a_LU_info, pivots_info, info_ = a.lu(get_infos=True)
self.assertEqual(a_LU_info.size(), torch.Size(batches + (matrix_size, matrix_size)))
self.assertEqual(pivots_info.size(), torch.Size(batches + (matrix_size,)))
self.assertEqual(info_.size(), torch.Size(batches))
self.assertEqual(info_.abs().sum(), 0)
a_LU, pivots = a.btrifact()
a_LU, pivots = a.lu()
self.assertEqual(a_LU, a_LU_info)
self.assertEqual(pivots_info, pivots)
if a.is_cuda:
a_LU_info_nopiv, nopiv, info_nopiv = a.btrifact_with_info(pivot=False)
self.assertIsNone(nopiv)
a_LU_info_nopiv, nopiv, info_nopiv = a.lu(pivot=False, get_infos=True)
self.assertEqual(nopiv, cast(torch.zeros(a.shape[:-1], dtype=torch.int32)))
self.assertEqual(info_, info_nopiv)
P, L, U = torch.btriunpack(a_LU, pivots)
self.assertEqual(P.matmul(L.matmul(U)), a)
for ms, batch in product([3, 5, 7], [(2,), (3,), (3, 5)]):
for ms, batch in product([3, 5, 7], [(), (2,), (3,), (3, 5)]):
run_test(ms, batch, cast)
# Info should be positive for rank deficient matrices
a = cast(fullrank(3, 5))
a = cast(torch.ones(5, 3, 3))
if not (a.is_cuda and any(x in torch.version.cuda for x in ['8.0', '9.2'])):
a[0, 1] = 2 * a[0, 0] # Row 2 of a[0] is 2 times Row 1 of a[0], thereby causing a rank deficiency
self.assertGreater(a.btrifact_with_info()[2][0], 0)
self.assertGreater(a.lu(get_infos=True)[2][0], 0)
# Error checking, no pivoting variant on CPU
with self.assertRaisesRegex(RuntimeError,
'btrifact without pivoting is not implemented on the CPU'):
torch.btrifact(torch.empty(1, 2, 2), pivot=False)
'lu without pivoting is not implemented on the CPU'):
torch.lu(torch.empty(1, 2, 2), pivot=False)
@skipIfNoLapack
@skipIfRocm
def test_btrifact(self):
self._test_btrifact(self, lambda t: t)
def test_lu(self):
self._test_lu(self, lambda t: t)
@staticmethod
def _test_btrisolve(self, cast):
@ -1797,7 +1796,7 @@ class _TestTorchMixin(object):
(-1.56, 4.00),
(9.81, -4.09)))
a, b = cast(a), cast(b)
LU_data, pivots, info = a.btrifact_with_info()
LU_data, pivots, info = a.lu(get_infos=True)
self.assertEqual(info.abs().sum(), 0)
x = torch.btrisolve(b, LU_data, pivots)
b_ = torch.bmm(a, x.unsqueeze(2)).squeeze()
@ -1811,12 +1810,11 @@ class _TestTorchMixin(object):
def _test_btriunpack(self, cast):
def run_test(shape, cast):
a = cast(torch.randn(*shape))
a_lu, p = torch.btrifact(a.reshape(-1, shape[-1], shape[-1]))
a_lu = a_lu.reshape_as(a)
p = p.reshape(a.shape[:-1])
a_lu, p = torch.lu(a)
p_ref, l_ref, u_ref = torch.btriunpack(a_lu, p)
self.assertEqual(p_ref.matmul(l_ref.matmul(u_ref)), a)
run_test((3, 3), cast)
run_test((5, 3, 3), cast)
run_test((7, 3, 5, 5), cast)
run_test((7, 5, 3, 3, 3), cast)
@ -4743,11 +4741,14 @@ class _TestTorchMixin(object):
A = torch.randn(3, 3, device=A_device)
err_str = "Expected b and A to be on the same device"
with self.assertRaisesRegex(RuntimeError, err_str):
torch.gesv(b, A)
torch.solve(b, A)
with self.assertRaisesRegex(RuntimeError, err_str):
torch.cholesky_solve(b, A)
with self.assertRaisesRegex(RuntimeError, err_str):
torch.triangular_solve(b, A)
@skipIfNoLapack
def test_qr(self):
@ -7965,12 +7966,12 @@ class _TestTorchMixin(object):
self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,)))
if torch._C.has_lapack:
# btrifact
A_LU, pivots = fn(torch.btrifact, (0, 5, 5))
# lu
A_LU, pivots = fn(torch.lu, (0, 5, 5))
self.assertEqual([(0, 5, 5), (0, 5)], [A_LU.shape, pivots.shape])
A_LU, pivots = fn(torch.btrifact, (0, 0, 0))
A_LU, pivots = fn(torch.lu, (0, 0, 0))
self.assertEqual([(0, 0, 0), (0, 0)], [A_LU.shape, pivots.shape])
A_LU, pivots = fn(torch.btrifact, (2, 0, 0))
A_LU, pivots = fn(torch.lu, (2, 0, 0))
self.assertEqual([(2, 0, 0), (2, 0)], [A_LU.shape, pivots.shape])
@skipIfRocm

View File

@ -192,12 +192,6 @@
self: grad.bmm(mat2.transpose(1, 2))
mat2: self.transpose(1, 2).bmm(grad)
- name: btrifact(Tensor self, bool pivot)
self: not_implemented("btrifact")
- name: btrifact_with_info(Tensor self, bool pivot)
self: not_implemented("btrifact_with_info")
- name: btrisolve(Tensor self, Tensor LU_data, Tensor LU_pivots)
self: not_implemented("btrisolve")
@ -470,6 +464,9 @@
self: zeros_like(self)
other: zeros_like(other)
- name: _lu_with_info(Tensor self, bool pivot, bool check_errors)
self: not_implemented("lu_with_info")
- name: masked_fill_(Tensor self, Tensor mask, Scalar value)
self: grad.clone().masked_fill_(mask, 0)

View File

@ -27,7 +27,7 @@ SKIP_PYTHON_BINDINGS = [
'_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*',
'_th_.*', '_thnn_.*',
'arange.*', 'range.*', '_solve.*', '_getri.*', '_inverse.*',
'_cholesky.*', '_btrifact.*', '_triangular_solve.*',
'_cholesky.*', '_triangular_solve.*',
'slice', 'randint(_out)?',
'item', '_local_scalar_dense',
'max_pool1d', 'max_pool2d', 'max_pool3d', 'linear', 'to',

View File

@ -78,6 +78,7 @@ class Tensor:
center=True, pad_mode='reflect', normalized=False, onesided=True): ...
def split(self, split_size, dim=0): ...
def unique(self, sorted=True, return_inverse=False, dim=None): ...
def lu(self, pivot=True, get_infos=False): ...
${function_hints}

View File

@ -486,20 +486,6 @@ bmm(batch2) -> Tensor
See :func:`torch.bmm`
""")
add_docstr_all('btrifact',
r"""
btrifact(pivot=True) -> (Tensor, Tensor)
See :func:`torch.btrifact`
""")
add_docstr_all('btrifact_with_info',
r"""
btrifact_with_info(pivot=True) -> (Tensor, Tensor, Tensor)
See :func:`torch.btrifact_with_info`
""")
add_docstr_all('btrisolve',
r"""
btrisolve(LU_data, LU_pivots) -> Tensor
@ -1019,13 +1005,6 @@ ger(vec2) -> Tensor
See :func:`torch.ger`
""")
add_docstr_all('solve',
r"""
solve(A) -> Tensor, Tensor
See :func:`torch.solve`
""")
add_docstr_all('indices',
r"""
indices() -> Tensor
@ -2228,6 +2207,13 @@ Example::
""")
add_docstr_all('solve',
r"""
solve(A) -> Tensor, Tensor
See :func:`torch.solve`
""")
add_docstr_all('sort',
r"""
sort(dim=-1, descending=False) -> (Tensor, LongTensor)

View File

@ -5516,71 +5516,6 @@ Example::
[ 0., 0., 0.]])
""".format(**factory_like_common_args))
add_docstr(torch.btrifact,
r"""
btrifact(A, pivot=True) -> (Tensor, IntTensor)
Batch LU factorization.
Returns a tuple containing the LU factorization and pivots. Pivoting is done if
:attr:`pivot` is set.
.. note::
LU factorization with :attr:`pivot` = ``True`` is not available for CPU, and attempting
to do so will throw an error. However, LU factorization with :attr:`pivot` = ``True`` is
available for CUDA.
Arguments:
A (Tensor): the tensor to factor
pivot (bool, optional): controls whether pivoting is done
Returns:
A tuple containing factorization and pivots.
Example::
>>> A = torch.randn(2, 3, 3)
>>> A_LU, pivots = torch.btrifact(A)
>>> A_LU
tensor([[[ 1.3506, 2.5558, -0.0816],
[ 0.1684, 1.1551, 0.1940],
[ 0.1193, 0.6189, -0.5497]],
[[ 0.4526, 1.2526, -0.3285],
[-0.7988, 0.7175, -0.9701],
[ 0.2634, -0.9255, -0.3459]]])
>>> pivots
tensor([[ 3, 3, 3],
[ 3, 3, 3]], dtype=torch.int32)
""")
add_docstr(torch.btrifact_with_info,
r"""
btrifact_with_info(A, pivot=True) -> (Tensor, IntTensor, IntTensor)
Batch LU factorization with additional error information.
This is a version of :meth:`torch.btrifact` that always creates an info
`IntTensor`, and returns it as the third return value.
Arguments:
A (Tensor): the tensor to factor
pivot (bool, optional): controls whether pivoting is done
Returns:
A tuple containing factorization, pivots, and an `IntTensor` where non-zero
values indicate whether factorization for each minibatch sample succeeds.
Example::
>>> A = torch.randn(2, 3, 3)
>>> A_LU, pivots, info = A.btrifact_with_info()
>>> if info.nonzero().size(0) == 0:
>>> print('LU factorization succeeded for all samples!')
LU factorization succeeded for all samples!
""")
add_docstr(torch.btrisolve,
r"""
btrisolve(b, LU_data, LU_pivots) -> Tensor
@ -5591,14 +5526,14 @@ Returns the LU solve of the linear system :math:`Ax = b`.
Arguments:
b (Tensor): the RHS tensor
LU_data (Tensor): the pivoted LU factorization of A from :meth:`btrifact`.
LU_data (Tensor): the pivoted LU factorization of A from :meth:`torch.lu`.
LU_pivots (IntTensor): the pivots of the LU factorization
Example::
>>> A = torch.randn(2, 3, 3)
>>> b = torch.randn(2, 3)
>>> A_LU = torch.btrifact(A)
>>> A_LU = torch.lu(A)
>>> x = torch.btrisolve(b, *A_LU)
>>> torch.norm(torch.bmm(A, x.unsqueeze(2)) - b.unsqueeze(2))
tensor(1.00000e-07 *

View File

@ -6,23 +6,26 @@ import warnings
__all__ = [
'btriunpack',
'broadcast_tensors',
'btrifact',
'btrifact_with_info',
'cartesian_prod',
'chain_matmul',
'einsum',
'broadcast_tensors',
'gesv',
'isfinite',
'isinf',
'lu',
'norm',
'meshgrid',
'potrf',
'pstrf',
'potrs',
'gesv',
'split',
'stft',
'tensordot',
'trtrs',
'unique',
'cartesian_prod',
]
@ -81,7 +84,7 @@ def split(tensor, split_size_or_sections, dim=0):
def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
r"""Unpacks the data and pivots from a batched LU factorization (btrifact) of a tensor.
r"""Unpacks the data and pivots from a LU factorization of a tensor.
Returns a tuple of tensors as ``(the pivots, the L tensor, the U tensor)``.
@ -94,7 +97,7 @@ def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
Example::
>>> A = torch.randn(2, 3, 3)
>>> A_LU, pivots = A.btrifact()
>>> A_LU, pivots = A.lu()
>>> P, A_L, A_U = torch.btriunpack(A_LU, pivots)
>>>
>>> # can recover A from factorization
@ -111,13 +114,20 @@ def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
L = U = None
if unpack_pivots:
P = torch.eye(sz, device=LU_data.device, dtype=LU_data.dtype).expand_as(LU_data).clone()
LU_pivots = LU_pivots - 1
for idx in product(*map(lambda x: list(range(x)), LU_data.shape[:-2])):
LU_pivots_zero_idx = LU_pivots - 1
if LU_data.dim() > 2:
P = torch.eye(sz, device=LU_data.device, dtype=LU_data.dtype).expand_as(LU_data).clone()
for idx in product(*map(lambda x: list(range(x)), LU_data.shape[:-2])):
final_order = list(range(sz))
for k, j in enumerate(LU_pivots_zero_idx[idx]):
final_order[k], final_order[j] = final_order[j], final_order[k]
P[idx] = P[idx].index_select(1, torch.as_tensor(final_order, device=LU_pivots.device))
else:
P = torch.eye(sz, device=LU_data.device, dtype=LU_data.dtype)
final_order = list(range(sz))
for k, j in enumerate(LU_pivots[idx]):
for k, j, in enumerate(LU_pivots_zero_idx):
final_order[k], final_order[j] = final_order[j], final_order[k]
P[idx] = P[idx].index_select(1, torch.as_tensor(final_order, device=LU_pivots.device))
P = P.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device))
else:
P = None
@ -751,6 +761,8 @@ def trtrs(b, A, upper=True, transpose=False, unitriangular=False, out=None):
In particular, solves :math:`AX = b` and assumes :math:`A` is upper-triangular
with the default keyword arguments.
For more information regarding :func:`torch.trtrs`, please check :func:`torch.triangular_solve`.
.. warning::
:func:`torch.trtrs` is deprecated in favour of :func:`torch.triangular_solve` and will be
removed in the next release. Please use :func:`torch.triangular_solve` instead.
@ -758,3 +770,112 @@ def trtrs(b, A, upper=True, transpose=False, unitriangular=False, out=None):
warnings.warn("torch.trtrs is deprecated in favour of torch.triangular_solve and will be "
"removed in the next release. Please use torch.triangular_solve instead.", stacklevel=2)
return torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular, out=out)
def btrifact(A, pivot=True, out=None):
r"""Returns a tuple containing the LU factorization and pivots of :attr:`A`.
Pivoting is done if :attr:`pivot` is set.
For more information regarding :func:`torch.btrifact`, please check :func:`torch.lu`.
.. warning::
:func:`torch.btrifact` is deprecated in favour of :func:`torch.lu` and will be
removed in the next release. Please use :func:`torch.lu` instead.
"""
warnings.warn("torch.btrifact is deprecated in favour of torch.lu and will be "
"removed in the next release. Please use torch.lu instead.", stacklevel=2)
return lu(A, pivot=pivot, get_infos=False, out=out)
def btrifact_with_info(A, pivot=True, out=None):
r"""Performs LU factorization and returns additional status information along with the LU
factorization and pivots.
For more information regarding :func:`torch.btrifact_with_info`, please check :func:`torch.lu`.
.. warning::
:func:`torch.btrifact_with_info` is deprecated in favour of :func:`torch.lu` and will
be removed in the next release. Please use :func:`torch.lu` with the :attr:`get_infos`
argument set to ``True`` instead.
"""
warnings.warn("torch.btrifact_with_info is deprecated in favour of torch.lu and will be "
"removed in the next release. Please use torch.lu with the get_infos argument "
"set to True instead.",
stacklevel=2)
return lu(A, pivot=pivot, get_infos=True, out=out)
def lu(A, pivot=True, get_infos=False, out=None):
r"""Computes the LU factorization of a square matrix or batches of square matrices
:attr:`A`. Returns a tuple containing the LU factorization and pivots of :attr:`A`.
Pivoting is done if :attr:`pivot` is set to ``True``.
.. note::
The pivots returned by the function are 1-indexed. If :attr:`pivot` is ``False``,
then the returned pivots is a tensor filled with zeros of the appropriate size.
.. note::
LU factorization with :attr:`pivot` = ``False`` is not available for CPU, and attempting
to do so will throw an error. However, LU factorization with :attr:`pivot` = ``False`` is
available for CUDA.
.. note::
This function does not check if the factorization was successful or not if
:attr:`get_infos` is ``True`` since the status of the factorization is present in the
third element of the return tuple.
Arguments:
A (Tensor): the tensor to factor of size :math:`(*, m, m)`
pivot (bool, optional): controls whether pivoting is done. Default: ``True``
get_infos (bool, optional): if set to ``True``, returns an info IntTensor.
Default: ``False``
out (tuple, optional): optional output tuple. If :attr:`get_infos` is ``True``,
then the elements in the tuple are Tensor, IntTensor,
and IntTensor. If :attr:`get_infos` is ``False``, then the
elements in the tuple are Tensor, IntTensor. Default: ``None``
Returns:
(Tensor, IntTensor, IntTensor (optional)): A tuple of tensors containing
- **factorization** (*Tensor*): the factorization of size :math:`(*, m, m)`
- **pivots** (*IntTensor*): the pivots of size :math:`(*, m)`
- **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of
size :math:`(*)` where non-zero values indicate whether factorization for the matrix or
each minibatch has succeeded or failed
Example::
>>> A = torch.randn(2, 3, 3)
>>> A_LU, pivots = torch.lu(A)
>>> A_LU
tensor([[[ 1.3506, 2.5558, -0.0816],
[ 0.1684, 1.1551, 0.1940],
[ 0.1193, 0.6189, -0.5497]],
[[ 0.4526, 1.2526, -0.3285],
[-0.7988, 0.7175, -0.9701],
[ 0.2634, -0.9255, -0.3459]]])
>>> pivots
tensor([[ 3, 3, 3],
[ 3, 3, 3]], dtype=torch.int32)
>>> A_LU, pivots, info = torch.lu(A, get_infos=True)
>>> if info.nonzero().size(0) == 0:
... print('LU factorization succeeded for all samples!')
LU factorization succeeded for all samples!
"""
# If get_infos is True, then we don't need to check for errors and vice versa
result = torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
if out is not None:
if not isinstance(out, (tuple, list)):
raise TypeError("argument 'out' must be tuple of Tensors, not {}"
.format(type(out).__name__))
if len(out) - int(get_infos) != 2:
raise TypeError("expected tuple of {} elements but got {}"
.format(2 + int(get_infos), len(out)))
return (out[i].resize_as_(result[i]).copy_(result[i]) for i in range(len(out)))
if get_infos:
return result # A_LU, pivots, infos
else:
return result[0], result[1] # A_LU, pivots

View File

@ -282,10 +282,33 @@ class Tensor(torch._C._TensorBase):
def trtrs(self, A, upper=True, transpose=False, unitriangular=False):
r"""See :func:`torch.triangular_solve`"""
warnings.warn("torch.trtrs is deprecated in favour of torch.triangular_solve and will be "
"removed in the next release. Please use torch.triangular_solve.", stacklevel=2)
"removed in the next release. Please use torch.triangular_solve instead.",
stacklevel=2)
return super(Tensor, self).triangular_solve(A, upper=upper,
transpose=transpose, unitriangular=unitriangular)
def btrifact(self, pivot=True):
r"""See :func:`torch.lu`"""
warnings.warn("torch.btrifact is deprecated in favour of torch.lu and will be removed in "
"the next release. Please use torch.lu instead.", stacklevel=2)
return torch._lu_with_info(self, pivot=pivot, check_errors=True)
def btrifact_with_info(self, pivot=True):
r"""See :func:`torch.lu`"""
warnings.warn("torch.btrifact_with_info is deprecated in favour of torch.lu with the "
"and will be removed in the next release. Please use torch.lu with the "
"get_infos argument set to True instead.", stacklevel=2)
return torch._lu_with_info(self, pivot=pivot, check_errors=False)
def lu(self, pivot=True, get_infos=False):
r"""See :func:`torch.lu`"""
# If get_infos is True, then we don't need to check for errors and vice versa
LU, pivots, infos = torch._lu_with_info(self, pivot=pivot, check_errors=(not get_infos))
if get_infos:
return LU, pivots, infos
else:
return LU, pivots
def stft(self, n_fft, hop_length=None, win_length=None, window=None,
center=True, pad_mode='reflect', normalized=False, onesided=True):
r"""See :func:`torch.stft`