mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Rename btrifact*
to lu
(#18435)
Summary: Changelog: - Renames `btrifact` and `btrifact_with_info` to `lu`to remain consistent with other factorization methods (`qr` and `svd`). - Now, we will only have one function and methods named `lu`, which performs `lu` decomposition. This function takes a get_infos kwarg, which when set to True includes a infos tensor in the tuple. - Rename all tests, fix callsites - Create a tentative alias for `lu` under the name `btrifact` and `btrifact_with_info`, and add a deprecation warning to not promote usage. - Add the single batch version for `lu` so that users don't have to unsqueeze and squeeze for a single square matrix (see changes in determinant computation in `LinearAlgebra.cpp`) Pull Request resolved: https://github.com/pytorch/pytorch/pull/18435 Differential Revision: D14680352 Pulled By: soumith fbshipit-source-id: af58dfc11fa53d9e8e0318c720beaf5502978cd8
This commit is contained in:
committed by
Facebook Github Bot
parent
c21e763cd6
commit
d859031ebf
@ -700,8 +700,6 @@ class CAFFE2_API Tensor {
|
||||
std::tuple<Tensor,Tensor> geqrf() const;
|
||||
Tensor orgqr(const Tensor & input2) const;
|
||||
Tensor ormqr(const Tensor & input2, const Tensor & input3, bool left=true, bool transpose=false) const;
|
||||
std::tuple<Tensor,Tensor> btrifact(bool pivot=true) const;
|
||||
std::tuple<Tensor,Tensor,Tensor> btrifact_with_info(bool pivot=true) const;
|
||||
Tensor btrisolve(const Tensor & LU_data, const Tensor & LU_pivots) const;
|
||||
Tensor multinomial(int64_t num_samples, bool replacement=false, Generator * generator=nullptr) const;
|
||||
Tensor lgamma() const;
|
||||
|
@ -1171,12 +1171,6 @@ inline Tensor Tensor::orgqr(const Tensor & input2) const {
|
||||
inline Tensor Tensor::ormqr(const Tensor & input2, const Tensor & input3, bool left, bool transpose) const {
|
||||
return type().ormqr(*this, input2, input3, left, transpose);
|
||||
}
|
||||
inline std::tuple<Tensor,Tensor> Tensor::btrifact(bool pivot) const {
|
||||
return type().btrifact(*this, pivot);
|
||||
}
|
||||
inline std::tuple<Tensor,Tensor,Tensor> Tensor::btrifact_with_info(bool pivot) const {
|
||||
return type().btrifact_with_info(*this, pivot);
|
||||
}
|
||||
inline Tensor Tensor::btrisolve(const Tensor & LU_data, const Tensor & LU_pivots) const {
|
||||
return type().btrisolve(*this, LU_data, LU_pivots);
|
||||
}
|
||||
|
@ -578,8 +578,6 @@ struct CAFFE2_API Type {
|
||||
virtual std::tuple<Tensor,Tensor> geqrf(const Tensor & self) const = 0;
|
||||
virtual Tensor orgqr(const Tensor & self, const Tensor & input2) const = 0;
|
||||
virtual Tensor ormqr(const Tensor & self, const Tensor & input2, const Tensor & input3, bool left, bool transpose) const = 0;
|
||||
virtual std::tuple<Tensor,Tensor> btrifact(const Tensor & self, bool pivot) const = 0;
|
||||
virtual std::tuple<Tensor,Tensor,Tensor> btrifact_with_info(const Tensor & self, bool pivot) const = 0;
|
||||
virtual Tensor btrisolve(const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots) const = 0;
|
||||
virtual Tensor multinomial(const Tensor & self, int64_t num_samples, bool replacement, Generator * generator) const = 0;
|
||||
virtual Tensor lgamma(const Tensor & self) const = 0;
|
||||
|
@ -88,6 +88,7 @@ _(aten, _log10) \
|
||||
_(aten, _log1p) \
|
||||
_(aten, _log2) \
|
||||
_(aten, _logspace) \
|
||||
_(aten, _lu_with_info) \
|
||||
_(aten, _masked_scale) \
|
||||
_(aten, _mm) \
|
||||
_(aten, _mv) \
|
||||
@ -224,8 +225,6 @@ _(aten, bincount) \
|
||||
_(aten, blackman_window) \
|
||||
_(aten, bmm) \
|
||||
_(aten, broadcast_tensors) \
|
||||
_(aten, btrifact) \
|
||||
_(aten, btrifact_with_info) \
|
||||
_(aten, btrisolve) \
|
||||
_(aten, cartesian_prod) \
|
||||
_(aten, cat) \
|
||||
|
@ -51,8 +51,8 @@ void lapackSolve(int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b,
|
||||
}
|
||||
|
||||
template<class scalar_t>
|
||||
void lapackGetrf(int m, int n, scalar_t *a, int lda, int *ipiv, int *info) {
|
||||
AT_ERROR("getrf only takes float or double Tensors");
|
||||
void lapackLu(int m, int n, scalar_t *a, int lda, int *ipiv, int *info) {
|
||||
AT_ERROR("lu only takes float or double Tensors");
|
||||
}
|
||||
|
||||
template<class scalar_t>
|
||||
@ -92,11 +92,11 @@ template<> void lapackGetri<float>(int n, float *a, int lda, int *ipiv, float *w
|
||||
sgetri_(&n, a, &lda, ipiv, work, &lwork, info);
|
||||
}
|
||||
|
||||
template<> void lapackGetrf<double>(int m, int n, double *a, int lda, int *ipiv, int *info) {
|
||||
template<> void lapackLu<double>(int m, int n, double *a, int lda, int *ipiv, int *info) {
|
||||
dgetrf_(&m, &n, a, &lda, ipiv, info);
|
||||
}
|
||||
|
||||
template<> void lapackGetrf<float>(int m, int n, float *a, int lda, int *ipiv, int *info) {
|
||||
template<> void lapackLu<float>(int m, int n, float *a, int lda, int *ipiv, int *info) {
|
||||
sgetrf_(&m, &n, a, &lda, ipiv, info);
|
||||
}
|
||||
|
||||
@ -219,7 +219,7 @@ static void apply_inverse(Tensor& self, std::vector<int64_t>& infos) {
|
||||
for (int64_t i = 0; i < batch_size; i++) {
|
||||
int info;
|
||||
scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
|
||||
lapackGetrf<scalar_t>(n, n, self_working_ptr, n, ipiv.data<int>(), &info);
|
||||
lapackLu<scalar_t>(n, n, self_working_ptr, n, ipiv.data<int>(), &info);
|
||||
infos[i] = info;
|
||||
if (info != 0) {
|
||||
return;
|
||||
@ -406,41 +406,44 @@ Tensor& cholesky_out(Tensor &result, const Tensor &self, bool upper) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ btrifact ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
template<typename scalar_t>
|
||||
static void apply_btrifact(Tensor& self, Tensor& pivots, Tensor& infos) {
|
||||
static void apply_lu(Tensor& self, Tensor& pivots, Tensor& infos) {
|
||||
#ifndef USE_LAPACK
|
||||
AT_ERROR("btrifact: LAPACK library not found in compilation");
|
||||
AT_ERROR("lu: LAPACK library not found in compilation");
|
||||
#else
|
||||
auto self_data = self.data<scalar_t>();
|
||||
auto self_matrix_stride = matrixStride(self);
|
||||
auto batch_size = batchCount(self);
|
||||
|
||||
auto pivots_data = pivots.data<int>();
|
||||
auto pivots_matrix_stride = pivots.size(-1);
|
||||
auto infos_data = infos.data<int>();
|
||||
|
||||
auto n = self.size(-1);
|
||||
|
||||
for (int64_t i = 0; i < batch_size; i++) {
|
||||
scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
|
||||
int* pivots_working_ptr = &pivots_data[i * pivots_matrix_stride];
|
||||
int* infos_working_ptr = &infos_data[i];
|
||||
lapackGetrf<scalar_t>(n, n, self_working_ptr, n, pivots_working_ptr, infos_working_ptr);
|
||||
if (self.dim() == 2) {
|
||||
lapackLu<scalar_t>(n, n, self_data, n, pivots_data, infos_data);
|
||||
} else {
|
||||
auto self_matrix_stride = matrixStride(self);
|
||||
auto batch_size = batchCount(self);
|
||||
auto pivots_matrix_stride = pivots.size(-1);
|
||||
for (int64_t i = 0; i < batch_size; i++) {
|
||||
scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
|
||||
int* pivots_working_ptr = &pivots_data[i * pivots_matrix_stride];
|
||||
int* infos_working_ptr = &infos_data[i];
|
||||
lapackLu<scalar_t>(n, n, self_working_ptr, n, pivots_working_ptr, infos_working_ptr);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor> _btrifact_helper_cpu(const Tensor& self, bool pivot) {
|
||||
AT_CHECK(pivot, "btrifact without pivoting is not implemented on the CPU");
|
||||
AT_CHECK(self.dim() > 2,
|
||||
"expected tensor with more than 2 dimensions, got size: ", self.sizes(),
|
||||
std::tuple<Tensor, Tensor, Tensor> _lu_with_info_cpu(const Tensor& self, bool pivot, bool check_errors) {
|
||||
AT_CHECK(pivot, "lu without pivoting is not implemented on the CPU");
|
||||
AT_CHECK(self.dim() >= 2,
|
||||
"expected tensor with 2 or more dimensions, got size: ", self.sizes(),
|
||||
" instead");
|
||||
squareCheckInputs(self);
|
||||
auto req_size = self.sizes().vec();
|
||||
req_size.pop_back();
|
||||
auto pivots_tensor = at::zeros(req_size, self.options().dtype(kInt));
|
||||
auto pivots_tensor = at::empty(req_size, self.options().dtype(kInt));
|
||||
req_size.pop_back();
|
||||
auto infos_tensor = at::zeros(req_size, self.options().dtype(kInt));
|
||||
|
||||
@ -449,55 +452,20 @@ std::tuple<Tensor, Tensor, Tensor> _btrifact_helper_cpu(const Tensor& self, bool
|
||||
self_working_copy = at::empty_like(self);
|
||||
} else {
|
||||
self_working_copy = cloneBatchedColumnMajor(self);
|
||||
AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "btrifact_cpu", [&]{
|
||||
apply_btrifact<scalar_t>(self_working_copy, pivots_tensor, infos_tensor);
|
||||
AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lu_cpu", [&]{
|
||||
apply_lu<scalar_t>(self_working_copy, pivots_tensor, infos_tensor);
|
||||
});
|
||||
}
|
||||
if (check_errors) {
|
||||
if (self.dim() == 2) {
|
||||
singleCheckErrors(infos_tensor.item<int64_t>(), "lu");
|
||||
} else {
|
||||
batchCheckErrors(infos_tensor, "lu");
|
||||
}
|
||||
}
|
||||
return std::make_tuple(self_working_copy, pivots_tensor, infos_tensor);
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor> btrifact(const Tensor& self, bool pivot) {
|
||||
Tensor LU_fact, pivots, infos;
|
||||
std::tie(LU_fact, pivots, infos) = at::_btrifact_helper(self, pivot);
|
||||
batchCheckErrors(infos, "btrifact");
|
||||
return std::make_tuple(LU_fact, pivots);
|
||||
}
|
||||
|
||||
std::tuple<Tensor&, Tensor&> btrifact_out(
|
||||
Tensor& A_LU,
|
||||
Tensor& pivots,
|
||||
const Tensor& self,
|
||||
bool pivot) {
|
||||
Tensor infos, A_LU_tmp, pivots_tmp;
|
||||
std::tie(A_LU_tmp, pivots_tmp, infos) = at::_btrifact_helper(self, pivot);
|
||||
batchCheckErrors(infos, "btrifact");
|
||||
A_LU.resize_as_(A_LU_tmp).copy_(A_LU_tmp);
|
||||
pivots.resize_as_(pivots_tmp).copy_(pivots_tmp);
|
||||
return std::tuple<Tensor&, Tensor&>(A_LU, pivots);
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor> btrifact_with_info(
|
||||
const Tensor& self,
|
||||
bool pivot) {
|
||||
Tensor LU_fact, pivots, infos;
|
||||
std::tie(LU_fact, pivots, infos) = at::_btrifact_helper(self, pivot);
|
||||
return std::make_tuple(LU_fact, pivots, infos);
|
||||
}
|
||||
|
||||
std::tuple<Tensor&, Tensor&, Tensor&> btrifact_with_info_out(
|
||||
Tensor& A_LU,
|
||||
Tensor& pivots,
|
||||
Tensor& info,
|
||||
const Tensor& self,
|
||||
bool pivot) {
|
||||
Tensor info_tmp, A_LU_tmp, pivots_tmp;
|
||||
std::tie(A_LU_tmp, pivots_tmp, info_tmp) = at::_btrifact_helper(self, pivot);
|
||||
A_LU.resize_as_(A_LU_tmp).copy_(A_LU_tmp);
|
||||
pivots.resize_as_(pivots_tmp).copy_(pivots_tmp);
|
||||
info.resize_as_(info_tmp).copy_(info_tmp);
|
||||
return std::tuple<Tensor&, Tensor&, Tensor&>(A_LU, pivots, info);
|
||||
}
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triu/tril ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
template <typename scalar_t, bool upper>
|
||||
|
@ -21,10 +21,8 @@ namespace native {
|
||||
// where info helps us identify singular matrices.
|
||||
static inline std::tuple<double, Tensor, int> _lu_det_P_diag_U_info(const Tensor& self) {
|
||||
Tensor p, lu, info;
|
||||
std::tie(lu, p, info) = self.unsqueeze(0).btrifact_with_info();
|
||||
p.squeeze_(0);
|
||||
lu.squeeze_(0);
|
||||
int int_info = info.squeeze_().item<int32_t>();
|
||||
std::tie(lu, p, info) = at::_lu_with_info(self, /*pivot=*/true, /*check_errors=*/false);
|
||||
int int_info = info.item<int32_t>();
|
||||
AT_CHECK(int_info >= 0, "LU factorization (getrf) failed with info = ", int_info);
|
||||
auto n = self.size(0);
|
||||
auto num_exchanges = (at::arange(1, n + 1, p.options()) != p).nonzero().size(0);
|
||||
|
@ -109,7 +109,7 @@ static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A) {
|
||||
" but each b matrix is ", self.size(-2), " by ", self.size(-1));
|
||||
}
|
||||
|
||||
// Validates input shapes for operations on batches of square matrices (inverse, cholesky)
|
||||
// Validates input shapes for operations on batches of square matrices (inverse, cholesky, lu)
|
||||
static inline void squareCheckInputs(const Tensor& self) {
|
||||
AT_CHECK(self.size(-1) == self.size(-2),
|
||||
"A must be batches of square matrices, "
|
||||
|
@ -35,18 +35,32 @@ void magmaSolveBatched(
|
||||
}
|
||||
|
||||
template<class scalar_t>
|
||||
void magmaGetrfBatched(
|
||||
magma_int_t m, magma_int_t n, scalar_t** dA_array, magma_int_t ldda,
|
||||
magma_int_t** ipiv_array, magma_int_t* info_array, magma_int_t batchsize,
|
||||
const MAGMAQueue& magma_queue) {
|
||||
AT_ERROR("getrf only takes float or double Tensors");
|
||||
void magmaLu(
|
||||
magma_int_t m, magma_int_t n, scalar_t* dA, magma_int_t ldda,
|
||||
magma_int_t* ipiv, magma_int_t* info) {
|
||||
AT_ERROR("lu only takes float or double Tensors");
|
||||
}
|
||||
|
||||
template<class scalar_t>
|
||||
void magmaGetrfNoPivBatched(
|
||||
void magmaLuBatched(
|
||||
magma_int_t m, magma_int_t n, scalar_t** dA_array, magma_int_t ldda,
|
||||
magma_int_t** ipiv_array, magma_int_t* info_array, magma_int_t batchsize,
|
||||
const MAGMAQueue& magma_queue) {
|
||||
AT_ERROR("lu only takes float or double Tensors");
|
||||
}
|
||||
|
||||
template<class scalar_t>
|
||||
void magmaLuNoPiv(
|
||||
magma_int_t m, magma_int_t n, scalar_t* dA, magma_int_t ldda,
|
||||
magma_int_t* info) {
|
||||
AT_ERROR("lu only takes float or double Tensors");
|
||||
}
|
||||
|
||||
template<class scalar_t>
|
||||
void magmaLuNoPivBatched(
|
||||
magma_int_t m, magma_int_t n, scalar_t** dA_array, magma_int_t ldda,
|
||||
magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
|
||||
AT_ERROR("getrf only takes float or double Tensors");
|
||||
AT_ERROR("lu only takes float or double Tensors");
|
||||
}
|
||||
|
||||
template<class scalar_t>
|
||||
@ -131,7 +145,21 @@ void magmaSolveBatched<float>(
|
||||
}
|
||||
|
||||
template<>
|
||||
void magmaGetrfBatched<double>(
|
||||
void magmaLu<double>(
|
||||
magma_int_t m, magma_int_t n, double* dA, magma_int_t ldda,
|
||||
magma_int_t* ipiv, magma_int_t* info) {
|
||||
magma_dgetrf_gpu(m, n, dA, ldda, ipiv, info);
|
||||
}
|
||||
|
||||
template<>
|
||||
void magmaLu<float>(
|
||||
magma_int_t m, magma_int_t n, float* dA, magma_int_t ldda,
|
||||
magma_int_t* ipiv, magma_int_t* info) {
|
||||
magma_sgetrf_gpu(m, n, dA, ldda, ipiv, info);
|
||||
}
|
||||
|
||||
template<>
|
||||
void magmaLuBatched<double>(
|
||||
magma_int_t m, magma_int_t n, double** dA_array, magma_int_t ldda,
|
||||
magma_int_t** ipiv_array, magma_int_t* info_array, magma_int_t batchsize,
|
||||
const MAGMAQueue& magma_queue) {
|
||||
@ -139,7 +167,7 @@ void magmaGetrfBatched<double>(
|
||||
}
|
||||
|
||||
template<>
|
||||
void magmaGetrfBatched<float>(
|
||||
void magmaLuBatched<float>(
|
||||
magma_int_t m, magma_int_t n, float** dA_array, magma_int_t ldda,
|
||||
magma_int_t** ipiv_array, magma_int_t* info_array, magma_int_t batchsize,
|
||||
const MAGMAQueue& magma_queue) {
|
||||
@ -147,14 +175,28 @@ void magmaGetrfBatched<float>(
|
||||
}
|
||||
|
||||
template<>
|
||||
void magmaGetrfNoPivBatched<double>(
|
||||
void magmaLuNoPiv<double>(
|
||||
magma_int_t m, magma_int_t n, double* dA, magma_int_t ldda,
|
||||
magma_int_t* info) {
|
||||
magma_dgetrf_nopiv_gpu(m, n, dA, ldda, info);
|
||||
}
|
||||
|
||||
template<>
|
||||
void magmaLuNoPiv<float>(
|
||||
magma_int_t m, magma_int_t n, float* dA, magma_int_t ldda,
|
||||
magma_int_t* info) {
|
||||
magma_sgetrf_nopiv_gpu(m, n, dA, ldda, info);
|
||||
}
|
||||
|
||||
template<>
|
||||
void magmaLuNoPivBatched<double>(
|
||||
magma_int_t m, magma_int_t n, double** dA_array, magma_int_t ldda,
|
||||
magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
|
||||
magma_dgetrf_nopiv_batched(m, n, dA_array, ldda, info_array, batchsize, magma_queue.get_queue());
|
||||
}
|
||||
|
||||
template<>
|
||||
void magmaGetrfNoPivBatched<float>(
|
||||
void magmaLuNoPivBatched<float>(
|
||||
magma_int_t m, magma_int_t n, float** dA_array, magma_int_t ldda,
|
||||
magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
|
||||
magma_sgetrf_nopiv_batched(m, n, dA_array, ldda, info_array, batchsize, magma_queue.get_queue());
|
||||
@ -373,7 +415,7 @@ AT_ERROR("inverse: MAGMA library not found in "
|
||||
}
|
||||
|
||||
MAGMAQueue magma_queue(self.get_device());
|
||||
magmaGetrfBatched<scalar_t>(
|
||||
magmaLuBatched<scalar_t>(
|
||||
n, n, self_array, n, ipiv_array, info_array,
|
||||
batch_size, magma_queue);
|
||||
|
||||
@ -527,75 +569,96 @@ Tensor _cholesky_helper_cuda(const Tensor& self, bool upper) {
|
||||
}
|
||||
}
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ btrifact ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
template <typename scalar_t>
|
||||
static void apply_btrifact(Tensor& self, Tensor& pivots, Tensor& infos) {
|
||||
static void apply_lu(Tensor& self, Tensor& pivots, Tensor& infos, bool get_pivots) {
|
||||
#ifndef USE_MAGMA
|
||||
AT_ERROR("btrifact: MAGMA library not found in "
|
||||
AT_ERROR("lu: MAGMA library not found in "
|
||||
"compilation. Please rebuild with MAGMA.");
|
||||
#else
|
||||
auto self_data = self.data<scalar_t>();
|
||||
auto self_matrix_stride = matrixStride(self);
|
||||
magma_int_t batch_size = magma_int_cast(batchCount(self), "batchCount");
|
||||
magma_int_t n = magma_int_cast(self.size(-1), "n");
|
||||
|
||||
scalar_t** self_array;
|
||||
ALLOCATE_ARRAY(self_array, scalar_t*, batch_size, self);
|
||||
if (self.dim() == 2) {
|
||||
// If `pivots` is defined, then we have to compute them.
|
||||
// We will use the normal getrf function to compute the LU factorization
|
||||
// and the pivots
|
||||
// We create temporary tensors on the CPU, because tensors on the GPU
|
||||
// cause segfault when passed to magmaLu and magmaLuNoPiv. The data is later
|
||||
// copied to the appropriate tensors.
|
||||
Tensor info_tmp = at::zeros({}, at::kInt);
|
||||
if (get_pivots) {
|
||||
Tensor piv_tmp = at::empty({n}, at::kInt);
|
||||
magmaLu<scalar_t>(
|
||||
n, n, self_data, n, piv_tmp.data<magma_int_t>(), info_tmp.data<magma_int_t>());
|
||||
pivots.copy_(piv_tmp);
|
||||
} else {
|
||||
magmaLuNoPiv<scalar_t>(n, n, self_data, n, info_tmp.data<magma_int_t>());
|
||||
}
|
||||
infos.copy_(info_tmp);
|
||||
} else {
|
||||
auto self_matrix_stride = matrixStride(self);
|
||||
magma_int_t batch_size = magma_int_cast(batchCount(self), "batchCount");
|
||||
|
||||
// Set up the created arrays
|
||||
for (int64_t i = 0; i < batch_size; i++) {
|
||||
self_array[i] = &self_data[i * self_matrix_stride];
|
||||
}
|
||||
scalar_t** self_array;
|
||||
ALLOCATE_ARRAY(self_array, scalar_t*, batch_size, self);
|
||||
|
||||
MAGMAQueue magma_queue(self.get_device());
|
||||
|
||||
// If `pivots` is defined, then we have to compute them.
|
||||
// We will use the normal getrf function to compute the LU factorization
|
||||
// and the pivots
|
||||
if (pivots.defined()) {
|
||||
auto pivots_data = pivots.data<magma_int_t>();
|
||||
auto pivots_matrix_stride = pivots.size(-1);
|
||||
magma_int_t** pivots_array;
|
||||
ALLOCATE_ARRAY(pivots_array, magma_int_t*, batch_size, pivots);
|
||||
// Set up the created arrays
|
||||
for (int64_t i = 0; i < batch_size; i++) {
|
||||
pivots_array[i] = &pivots_data[i * pivots_matrix_stride];
|
||||
self_array[i] = &self_data[i * self_matrix_stride];
|
||||
}
|
||||
|
||||
magmaGetrfBatched<scalar_t>(
|
||||
n, n, self_array, n, pivots_array,
|
||||
infos.data<magma_int_t>(), batch_size, magma_queue);
|
||||
} else {
|
||||
magmaGetrfNoPivBatched<scalar_t>(
|
||||
n, n, self_array, n, infos.data<magma_int_t>(),
|
||||
batch_size, magma_queue);
|
||||
MAGMAQueue magma_queue(self.get_device());
|
||||
|
||||
// Same comment as in the case of single matrix above.
|
||||
if (get_pivots) {
|
||||
auto pivots_data = pivots.data<magma_int_t>();
|
||||
auto pivots_matrix_stride = pivots.size(-1);
|
||||
magma_int_t** pivots_array;
|
||||
ALLOCATE_ARRAY(pivots_array, magma_int_t*, batch_size, pivots);
|
||||
for (int64_t i = 0; i < batch_size; i++) {
|
||||
pivots_array[i] = &pivots_data[i * pivots_matrix_stride];
|
||||
}
|
||||
magmaLuBatched<scalar_t>(
|
||||
n, n, self_array, n, pivots_array,
|
||||
infos.data<magma_int_t>(), batch_size, magma_queue);
|
||||
} else {
|
||||
magmaLuNoPivBatched<scalar_t>(
|
||||
n, n, self_array, n, infos.data<magma_int_t>(),
|
||||
batch_size, magma_queue);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor> _btrifact_helper_cuda(const Tensor& self, bool pivot) {
|
||||
AT_CHECK(self.dim() > 2,
|
||||
"expected tensor with more than 2 dimensions, got size: ", self.sizes(),
|
||||
std::tuple<Tensor, Tensor, Tensor> _lu_with_info_cuda(const Tensor& self, bool pivot, bool check_errors) {
|
||||
AT_CHECK(self.dim() >= 2,
|
||||
"expected tensor with 2 or more dimensions, got size: ", self.sizes(),
|
||||
" instead");
|
||||
squareCheckInputs(self);
|
||||
auto req_size = self.sizes().vec();
|
||||
req_size.pop_back();
|
||||
Tensor pivots_tensor;
|
||||
if (pivot) {
|
||||
pivots_tensor = at::zeros(req_size, self.options().dtype(kInt));
|
||||
}
|
||||
Tensor pivots_tensor = at::zeros(req_size, self.options().dtype(at::kInt));
|
||||
req_size.pop_back();
|
||||
auto infos_tensor = at::zeros(req_size, self.options().dtype(kInt));
|
||||
auto infos_tensor = at::zeros(req_size, self.options().dtype(at::kInt));
|
||||
|
||||
Tensor self_working_copy;
|
||||
if (self.numel() == 0) {
|
||||
self_working_copy = at::empty_like(self);
|
||||
} else {
|
||||
self_working_copy = cloneBatchedColumnMajor(self);
|
||||
AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "btrifact_cuda", [&]{
|
||||
apply_btrifact<scalar_t>(self_working_copy, pivots_tensor, infos_tensor);
|
||||
AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lu_cuda", [&]{
|
||||
apply_lu<scalar_t>(self_working_copy, pivots_tensor, infos_tensor, pivot);
|
||||
});
|
||||
}
|
||||
if (check_errors) {
|
||||
if (self.dim() == 2) {
|
||||
singleCheckErrors(infos_tensor.item<int64_t>(), "lu");
|
||||
} else {
|
||||
batchCheckErrors(infos_tensor, "lu");
|
||||
}
|
||||
}
|
||||
return std::make_tuple(self_working_copy, pivots_tensor, infos_tensor);
|
||||
}
|
||||
|
||||
|
@ -3818,26 +3818,12 @@
|
||||
matches_jit_signature: True
|
||||
variants: method, function
|
||||
|
||||
- func: btrifact(Tensor self, *, bool pivot=True, Tensor(a!) A_LU, Tensor(b!) pivots) -> (Tensor(a!), Tensor(b!))
|
||||
matches_jit_signature: True
|
||||
|
||||
- func: btrifact(Tensor self, *, bool pivot=True) -> (Tensor, Tensor)
|
||||
matches_jit_signature: True
|
||||
variants: method, function
|
||||
|
||||
- func: btrifact_with_info(Tensor self, *, bool pivot=True, Tensor(a!) A_LU, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!), Tensor(b!), Tensor(c!))
|
||||
matches_jit_signature: True
|
||||
|
||||
- func: btrifact_with_info(Tensor self, *, bool pivot=True) -> (Tensor, Tensor, Tensor)
|
||||
matches_jit_signature: True
|
||||
variants: method, function
|
||||
|
||||
- func: _btrifact_helper(Tensor self, bool pivot) -> (Tensor, Tensor, Tensor)
|
||||
- func: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor, Tensor, Tensor)
|
||||
matches_jit_signature: True
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU: _btrifact_helper_cpu
|
||||
CUDA: _btrifact_helper_cuda
|
||||
CPU: _lu_with_info_cpu
|
||||
CUDA: _lu_with_info_cuda
|
||||
|
||||
- func: btrisolve(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!)
|
||||
matches_jit_signature: True
|
||||
|
@ -306,6 +306,7 @@ view of a storage and defines numeric operations on it.
|
||||
.. automethod:: long
|
||||
.. automethod:: lt
|
||||
.. automethod:: lt_
|
||||
.. automethod:: lu
|
||||
.. automethod:: map_
|
||||
.. automethod:: masked_scatter_
|
||||
.. automethod:: masked_scatter
|
||||
|
@ -315,6 +315,7 @@ BLAS and LAPACK Operations
|
||||
.. autofunction:: det
|
||||
.. autofunction:: logdet
|
||||
.. autofunction:: slogdet
|
||||
.. autofunction:: lu
|
||||
.. autofunction:: matmul
|
||||
.. autofunction:: matrix_power
|
||||
.. autofunction:: matrix_rank
|
||||
|
@ -2361,8 +2361,8 @@ class TestCuda(TestCase):
|
||||
|
||||
@skipIfRocm
|
||||
@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
|
||||
def test_btrifact(self):
|
||||
_TestTorchMixin._test_btrifact(self, lambda t: t.cuda())
|
||||
def test_lu(self):
|
||||
_TestTorchMixin._test_lu(self, lambda t: t.cuda())
|
||||
|
||||
@skipIfRocm
|
||||
@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
|
||||
|
@ -1746,44 +1746,43 @@ class _TestTorchMixin(object):
|
||||
_test_mm(n, m, p, torch.int64, lambda x, y: torch.randint(0, 100, (x, y), dtype=torch.int64))
|
||||
|
||||
@staticmethod
|
||||
def _test_btrifact(self, cast):
|
||||
def _test_lu(self, cast):
|
||||
from common_utils import random_fullrank_matrix_distinct_singular_value as fullrank
|
||||
|
||||
def run_test(matrix_size, batches, cast):
|
||||
a = cast(fullrank(matrix_size, *batches))
|
||||
a_LU_info, pivots_info, info_ = a.btrifact_with_info()
|
||||
a_LU_info, pivots_info, info_ = a.lu(get_infos=True)
|
||||
self.assertEqual(a_LU_info.size(), torch.Size(batches + (matrix_size, matrix_size)))
|
||||
self.assertEqual(pivots_info.size(), torch.Size(batches + (matrix_size,)))
|
||||
self.assertEqual(info_.size(), torch.Size(batches))
|
||||
self.assertEqual(info_.abs().sum(), 0)
|
||||
a_LU, pivots = a.btrifact()
|
||||
a_LU, pivots = a.lu()
|
||||
self.assertEqual(a_LU, a_LU_info)
|
||||
self.assertEqual(pivots_info, pivots)
|
||||
if a.is_cuda:
|
||||
a_LU_info_nopiv, nopiv, info_nopiv = a.btrifact_with_info(pivot=False)
|
||||
self.assertIsNone(nopiv)
|
||||
a_LU_info_nopiv, nopiv, info_nopiv = a.lu(pivot=False, get_infos=True)
|
||||
self.assertEqual(nopiv, cast(torch.zeros(a.shape[:-1], dtype=torch.int32)))
|
||||
self.assertEqual(info_, info_nopiv)
|
||||
P, L, U = torch.btriunpack(a_LU, pivots)
|
||||
self.assertEqual(P.matmul(L.matmul(U)), a)
|
||||
|
||||
for ms, batch in product([3, 5, 7], [(2,), (3,), (3, 5)]):
|
||||
for ms, batch in product([3, 5, 7], [(), (2,), (3,), (3, 5)]):
|
||||
run_test(ms, batch, cast)
|
||||
|
||||
# Info should be positive for rank deficient matrices
|
||||
a = cast(fullrank(3, 5))
|
||||
a = cast(torch.ones(5, 3, 3))
|
||||
if not (a.is_cuda and any(x in torch.version.cuda for x in ['8.0', '9.2'])):
|
||||
a[0, 1] = 2 * a[0, 0] # Row 2 of a[0] is 2 times Row 1 of a[0], thereby causing a rank deficiency
|
||||
self.assertGreater(a.btrifact_with_info()[2][0], 0)
|
||||
self.assertGreater(a.lu(get_infos=True)[2][0], 0)
|
||||
|
||||
# Error checking, no pivoting variant on CPU
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
'btrifact without pivoting is not implemented on the CPU'):
|
||||
torch.btrifact(torch.empty(1, 2, 2), pivot=False)
|
||||
'lu without pivoting is not implemented on the CPU'):
|
||||
torch.lu(torch.empty(1, 2, 2), pivot=False)
|
||||
|
||||
@skipIfNoLapack
|
||||
@skipIfRocm
|
||||
def test_btrifact(self):
|
||||
self._test_btrifact(self, lambda t: t)
|
||||
def test_lu(self):
|
||||
self._test_lu(self, lambda t: t)
|
||||
|
||||
@staticmethod
|
||||
def _test_btrisolve(self, cast):
|
||||
@ -1797,7 +1796,7 @@ class _TestTorchMixin(object):
|
||||
(-1.56, 4.00),
|
||||
(9.81, -4.09)))
|
||||
a, b = cast(a), cast(b)
|
||||
LU_data, pivots, info = a.btrifact_with_info()
|
||||
LU_data, pivots, info = a.lu(get_infos=True)
|
||||
self.assertEqual(info.abs().sum(), 0)
|
||||
x = torch.btrisolve(b, LU_data, pivots)
|
||||
b_ = torch.bmm(a, x.unsqueeze(2)).squeeze()
|
||||
@ -1811,12 +1810,11 @@ class _TestTorchMixin(object):
|
||||
def _test_btriunpack(self, cast):
|
||||
def run_test(shape, cast):
|
||||
a = cast(torch.randn(*shape))
|
||||
a_lu, p = torch.btrifact(a.reshape(-1, shape[-1], shape[-1]))
|
||||
a_lu = a_lu.reshape_as(a)
|
||||
p = p.reshape(a.shape[:-1])
|
||||
a_lu, p = torch.lu(a)
|
||||
p_ref, l_ref, u_ref = torch.btriunpack(a_lu, p)
|
||||
self.assertEqual(p_ref.matmul(l_ref.matmul(u_ref)), a)
|
||||
|
||||
run_test((3, 3), cast)
|
||||
run_test((5, 3, 3), cast)
|
||||
run_test((7, 3, 5, 5), cast)
|
||||
run_test((7, 5, 3, 3, 3), cast)
|
||||
@ -4743,11 +4741,14 @@ class _TestTorchMixin(object):
|
||||
A = torch.randn(3, 3, device=A_device)
|
||||
err_str = "Expected b and A to be on the same device"
|
||||
with self.assertRaisesRegex(RuntimeError, err_str):
|
||||
torch.gesv(b, A)
|
||||
torch.solve(b, A)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, err_str):
|
||||
torch.cholesky_solve(b, A)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, err_str):
|
||||
torch.triangular_solve(b, A)
|
||||
|
||||
@skipIfNoLapack
|
||||
def test_qr(self):
|
||||
|
||||
@ -7965,12 +7966,12 @@ class _TestTorchMixin(object):
|
||||
self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,)))
|
||||
|
||||
if torch._C.has_lapack:
|
||||
# btrifact
|
||||
A_LU, pivots = fn(torch.btrifact, (0, 5, 5))
|
||||
# lu
|
||||
A_LU, pivots = fn(torch.lu, (0, 5, 5))
|
||||
self.assertEqual([(0, 5, 5), (0, 5)], [A_LU.shape, pivots.shape])
|
||||
A_LU, pivots = fn(torch.btrifact, (0, 0, 0))
|
||||
A_LU, pivots = fn(torch.lu, (0, 0, 0))
|
||||
self.assertEqual([(0, 0, 0), (0, 0)], [A_LU.shape, pivots.shape])
|
||||
A_LU, pivots = fn(torch.btrifact, (2, 0, 0))
|
||||
A_LU, pivots = fn(torch.lu, (2, 0, 0))
|
||||
self.assertEqual([(2, 0, 0), (2, 0)], [A_LU.shape, pivots.shape])
|
||||
|
||||
@skipIfRocm
|
||||
|
@ -192,12 +192,6 @@
|
||||
self: grad.bmm(mat2.transpose(1, 2))
|
||||
mat2: self.transpose(1, 2).bmm(grad)
|
||||
|
||||
- name: btrifact(Tensor self, bool pivot)
|
||||
self: not_implemented("btrifact")
|
||||
|
||||
- name: btrifact_with_info(Tensor self, bool pivot)
|
||||
self: not_implemented("btrifact_with_info")
|
||||
|
||||
- name: btrisolve(Tensor self, Tensor LU_data, Tensor LU_pivots)
|
||||
self: not_implemented("btrisolve")
|
||||
|
||||
@ -470,6 +464,9 @@
|
||||
self: zeros_like(self)
|
||||
other: zeros_like(other)
|
||||
|
||||
- name: _lu_with_info(Tensor self, bool pivot, bool check_errors)
|
||||
self: not_implemented("lu_with_info")
|
||||
|
||||
- name: masked_fill_(Tensor self, Tensor mask, Scalar value)
|
||||
self: grad.clone().masked_fill_(mask, 0)
|
||||
|
||||
|
@ -27,7 +27,7 @@ SKIP_PYTHON_BINDINGS = [
|
||||
'_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*',
|
||||
'_th_.*', '_thnn_.*',
|
||||
'arange.*', 'range.*', '_solve.*', '_getri.*', '_inverse.*',
|
||||
'_cholesky.*', '_btrifact.*', '_triangular_solve.*',
|
||||
'_cholesky.*', '_triangular_solve.*',
|
||||
'slice', 'randint(_out)?',
|
||||
'item', '_local_scalar_dense',
|
||||
'max_pool1d', 'max_pool2d', 'max_pool3d', 'linear', 'to',
|
||||
|
@ -78,6 +78,7 @@ class Tensor:
|
||||
center=True, pad_mode='reflect', normalized=False, onesided=True): ...
|
||||
def split(self, split_size, dim=0): ...
|
||||
def unique(self, sorted=True, return_inverse=False, dim=None): ...
|
||||
def lu(self, pivot=True, get_infos=False): ...
|
||||
|
||||
${function_hints}
|
||||
|
||||
|
@ -486,20 +486,6 @@ bmm(batch2) -> Tensor
|
||||
See :func:`torch.bmm`
|
||||
""")
|
||||
|
||||
add_docstr_all('btrifact',
|
||||
r"""
|
||||
btrifact(pivot=True) -> (Tensor, Tensor)
|
||||
|
||||
See :func:`torch.btrifact`
|
||||
""")
|
||||
|
||||
add_docstr_all('btrifact_with_info',
|
||||
r"""
|
||||
btrifact_with_info(pivot=True) -> (Tensor, Tensor, Tensor)
|
||||
|
||||
See :func:`torch.btrifact_with_info`
|
||||
""")
|
||||
|
||||
add_docstr_all('btrisolve',
|
||||
r"""
|
||||
btrisolve(LU_data, LU_pivots) -> Tensor
|
||||
@ -1019,13 +1005,6 @@ ger(vec2) -> Tensor
|
||||
See :func:`torch.ger`
|
||||
""")
|
||||
|
||||
add_docstr_all('solve',
|
||||
r"""
|
||||
solve(A) -> Tensor, Tensor
|
||||
|
||||
See :func:`torch.solve`
|
||||
""")
|
||||
|
||||
add_docstr_all('indices',
|
||||
r"""
|
||||
indices() -> Tensor
|
||||
@ -2228,6 +2207,13 @@ Example::
|
||||
|
||||
""")
|
||||
|
||||
add_docstr_all('solve',
|
||||
r"""
|
||||
solve(A) -> Tensor, Tensor
|
||||
|
||||
See :func:`torch.solve`
|
||||
""")
|
||||
|
||||
add_docstr_all('sort',
|
||||
r"""
|
||||
sort(dim=-1, descending=False) -> (Tensor, LongTensor)
|
||||
|
@ -5516,71 +5516,6 @@ Example::
|
||||
[ 0., 0., 0.]])
|
||||
""".format(**factory_like_common_args))
|
||||
|
||||
add_docstr(torch.btrifact,
|
||||
r"""
|
||||
btrifact(A, pivot=True) -> (Tensor, IntTensor)
|
||||
|
||||
Batch LU factorization.
|
||||
|
||||
Returns a tuple containing the LU factorization and pivots. Pivoting is done if
|
||||
:attr:`pivot` is set.
|
||||
|
||||
.. note::
|
||||
LU factorization with :attr:`pivot` = ``True`` is not available for CPU, and attempting
|
||||
to do so will throw an error. However, LU factorization with :attr:`pivot` = ``True`` is
|
||||
available for CUDA.
|
||||
|
||||
Arguments:
|
||||
A (Tensor): the tensor to factor
|
||||
pivot (bool, optional): controls whether pivoting is done
|
||||
|
||||
Returns:
|
||||
A tuple containing factorization and pivots.
|
||||
|
||||
Example::
|
||||
|
||||
>>> A = torch.randn(2, 3, 3)
|
||||
>>> A_LU, pivots = torch.btrifact(A)
|
||||
>>> A_LU
|
||||
tensor([[[ 1.3506, 2.5558, -0.0816],
|
||||
[ 0.1684, 1.1551, 0.1940],
|
||||
[ 0.1193, 0.6189, -0.5497]],
|
||||
|
||||
[[ 0.4526, 1.2526, -0.3285],
|
||||
[-0.7988, 0.7175, -0.9701],
|
||||
[ 0.2634, -0.9255, -0.3459]]])
|
||||
|
||||
>>> pivots
|
||||
tensor([[ 3, 3, 3],
|
||||
[ 3, 3, 3]], dtype=torch.int32)
|
||||
""")
|
||||
|
||||
add_docstr(torch.btrifact_with_info,
|
||||
r"""
|
||||
btrifact_with_info(A, pivot=True) -> (Tensor, IntTensor, IntTensor)
|
||||
|
||||
Batch LU factorization with additional error information.
|
||||
|
||||
This is a version of :meth:`torch.btrifact` that always creates an info
|
||||
`IntTensor`, and returns it as the third return value.
|
||||
|
||||
Arguments:
|
||||
A (Tensor): the tensor to factor
|
||||
pivot (bool, optional): controls whether pivoting is done
|
||||
|
||||
Returns:
|
||||
A tuple containing factorization, pivots, and an `IntTensor` where non-zero
|
||||
values indicate whether factorization for each minibatch sample succeeds.
|
||||
|
||||
Example::
|
||||
|
||||
>>> A = torch.randn(2, 3, 3)
|
||||
>>> A_LU, pivots, info = A.btrifact_with_info()
|
||||
>>> if info.nonzero().size(0) == 0:
|
||||
>>> print('LU factorization succeeded for all samples!')
|
||||
LU factorization succeeded for all samples!
|
||||
""")
|
||||
|
||||
add_docstr(torch.btrisolve,
|
||||
r"""
|
||||
btrisolve(b, LU_data, LU_pivots) -> Tensor
|
||||
@ -5591,14 +5526,14 @@ Returns the LU solve of the linear system :math:`Ax = b`.
|
||||
|
||||
Arguments:
|
||||
b (Tensor): the RHS tensor
|
||||
LU_data (Tensor): the pivoted LU factorization of A from :meth:`btrifact`.
|
||||
LU_data (Tensor): the pivoted LU factorization of A from :meth:`torch.lu`.
|
||||
LU_pivots (IntTensor): the pivots of the LU factorization
|
||||
|
||||
Example::
|
||||
|
||||
>>> A = torch.randn(2, 3, 3)
|
||||
>>> b = torch.randn(2, 3)
|
||||
>>> A_LU = torch.btrifact(A)
|
||||
>>> A_LU = torch.lu(A)
|
||||
>>> x = torch.btrisolve(b, *A_LU)
|
||||
>>> torch.norm(torch.bmm(A, x.unsqueeze(2)) - b.unsqueeze(2))
|
||||
tensor(1.00000e-07 *
|
||||
|
@ -6,23 +6,26 @@ import warnings
|
||||
|
||||
__all__ = [
|
||||
'btriunpack',
|
||||
'broadcast_tensors',
|
||||
'btrifact',
|
||||
'btrifact_with_info',
|
||||
'cartesian_prod',
|
||||
'chain_matmul',
|
||||
'einsum',
|
||||
'broadcast_tensors',
|
||||
'gesv',
|
||||
'isfinite',
|
||||
'isinf',
|
||||
'lu',
|
||||
'norm',
|
||||
'meshgrid',
|
||||
'potrf',
|
||||
'pstrf',
|
||||
'potrs',
|
||||
'gesv',
|
||||
'split',
|
||||
'stft',
|
||||
'tensordot',
|
||||
'trtrs',
|
||||
'unique',
|
||||
'cartesian_prod',
|
||||
]
|
||||
|
||||
|
||||
@ -81,7 +84,7 @@ def split(tensor, split_size_or_sections, dim=0):
|
||||
|
||||
|
||||
def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
|
||||
r"""Unpacks the data and pivots from a batched LU factorization (btrifact) of a tensor.
|
||||
r"""Unpacks the data and pivots from a LU factorization of a tensor.
|
||||
|
||||
Returns a tuple of tensors as ``(the pivots, the L tensor, the U tensor)``.
|
||||
|
||||
@ -94,7 +97,7 @@ def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
|
||||
Example::
|
||||
|
||||
>>> A = torch.randn(2, 3, 3)
|
||||
>>> A_LU, pivots = A.btrifact()
|
||||
>>> A_LU, pivots = A.lu()
|
||||
>>> P, A_L, A_U = torch.btriunpack(A_LU, pivots)
|
||||
>>>
|
||||
>>> # can recover A from factorization
|
||||
@ -111,13 +114,20 @@ def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
|
||||
L = U = None
|
||||
|
||||
if unpack_pivots:
|
||||
P = torch.eye(sz, device=LU_data.device, dtype=LU_data.dtype).expand_as(LU_data).clone()
|
||||
LU_pivots = LU_pivots - 1
|
||||
for idx in product(*map(lambda x: list(range(x)), LU_data.shape[:-2])):
|
||||
LU_pivots_zero_idx = LU_pivots - 1
|
||||
if LU_data.dim() > 2:
|
||||
P = torch.eye(sz, device=LU_data.device, dtype=LU_data.dtype).expand_as(LU_data).clone()
|
||||
for idx in product(*map(lambda x: list(range(x)), LU_data.shape[:-2])):
|
||||
final_order = list(range(sz))
|
||||
for k, j in enumerate(LU_pivots_zero_idx[idx]):
|
||||
final_order[k], final_order[j] = final_order[j], final_order[k]
|
||||
P[idx] = P[idx].index_select(1, torch.as_tensor(final_order, device=LU_pivots.device))
|
||||
else:
|
||||
P = torch.eye(sz, device=LU_data.device, dtype=LU_data.dtype)
|
||||
final_order = list(range(sz))
|
||||
for k, j in enumerate(LU_pivots[idx]):
|
||||
for k, j, in enumerate(LU_pivots_zero_idx):
|
||||
final_order[k], final_order[j] = final_order[j], final_order[k]
|
||||
P[idx] = P[idx].index_select(1, torch.as_tensor(final_order, device=LU_pivots.device))
|
||||
P = P.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device))
|
||||
else:
|
||||
P = None
|
||||
|
||||
@ -751,6 +761,8 @@ def trtrs(b, A, upper=True, transpose=False, unitriangular=False, out=None):
|
||||
In particular, solves :math:`AX = b` and assumes :math:`A` is upper-triangular
|
||||
with the default keyword arguments.
|
||||
|
||||
For more information regarding :func:`torch.trtrs`, please check :func:`torch.triangular_solve`.
|
||||
|
||||
.. warning::
|
||||
:func:`torch.trtrs` is deprecated in favour of :func:`torch.triangular_solve` and will be
|
||||
removed in the next release. Please use :func:`torch.triangular_solve` instead.
|
||||
@ -758,3 +770,112 @@ def trtrs(b, A, upper=True, transpose=False, unitriangular=False, out=None):
|
||||
warnings.warn("torch.trtrs is deprecated in favour of torch.triangular_solve and will be "
|
||||
"removed in the next release. Please use torch.triangular_solve instead.", stacklevel=2)
|
||||
return torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular, out=out)
|
||||
|
||||
|
||||
def btrifact(A, pivot=True, out=None):
|
||||
r"""Returns a tuple containing the LU factorization and pivots of :attr:`A`.
|
||||
Pivoting is done if :attr:`pivot` is set.
|
||||
|
||||
For more information regarding :func:`torch.btrifact`, please check :func:`torch.lu`.
|
||||
|
||||
.. warning::
|
||||
:func:`torch.btrifact` is deprecated in favour of :func:`torch.lu` and will be
|
||||
removed in the next release. Please use :func:`torch.lu` instead.
|
||||
"""
|
||||
warnings.warn("torch.btrifact is deprecated in favour of torch.lu and will be "
|
||||
"removed in the next release. Please use torch.lu instead.", stacklevel=2)
|
||||
return lu(A, pivot=pivot, get_infos=False, out=out)
|
||||
|
||||
|
||||
def btrifact_with_info(A, pivot=True, out=None):
|
||||
r"""Performs LU factorization and returns additional status information along with the LU
|
||||
factorization and pivots.
|
||||
|
||||
For more information regarding :func:`torch.btrifact_with_info`, please check :func:`torch.lu`.
|
||||
|
||||
.. warning::
|
||||
:func:`torch.btrifact_with_info` is deprecated in favour of :func:`torch.lu` and will
|
||||
be removed in the next release. Please use :func:`torch.lu` with the :attr:`get_infos`
|
||||
argument set to ``True`` instead.
|
||||
"""
|
||||
warnings.warn("torch.btrifact_with_info is deprecated in favour of torch.lu and will be "
|
||||
"removed in the next release. Please use torch.lu with the get_infos argument "
|
||||
"set to True instead.",
|
||||
stacklevel=2)
|
||||
return lu(A, pivot=pivot, get_infos=True, out=out)
|
||||
|
||||
|
||||
def lu(A, pivot=True, get_infos=False, out=None):
|
||||
r"""Computes the LU factorization of a square matrix or batches of square matrices
|
||||
:attr:`A`. Returns a tuple containing the LU factorization and pivots of :attr:`A`.
|
||||
Pivoting is done if :attr:`pivot` is set to ``True``.
|
||||
|
||||
.. note::
|
||||
The pivots returned by the function are 1-indexed. If :attr:`pivot` is ``False``,
|
||||
then the returned pivots is a tensor filled with zeros of the appropriate size.
|
||||
|
||||
.. note::
|
||||
LU factorization with :attr:`pivot` = ``False`` is not available for CPU, and attempting
|
||||
to do so will throw an error. However, LU factorization with :attr:`pivot` = ``False`` is
|
||||
available for CUDA.
|
||||
|
||||
.. note::
|
||||
This function does not check if the factorization was successful or not if
|
||||
:attr:`get_infos` is ``True`` since the status of the factorization is present in the
|
||||
third element of the return tuple.
|
||||
|
||||
Arguments:
|
||||
A (Tensor): the tensor to factor of size :math:`(*, m, m)`
|
||||
pivot (bool, optional): controls whether pivoting is done. Default: ``True``
|
||||
get_infos (bool, optional): if set to ``True``, returns an info IntTensor.
|
||||
Default: ``False``
|
||||
out (tuple, optional): optional output tuple. If :attr:`get_infos` is ``True``,
|
||||
then the elements in the tuple are Tensor, IntTensor,
|
||||
and IntTensor. If :attr:`get_infos` is ``False``, then the
|
||||
elements in the tuple are Tensor, IntTensor. Default: ``None``
|
||||
|
||||
Returns:
|
||||
(Tensor, IntTensor, IntTensor (optional)): A tuple of tensors containing
|
||||
|
||||
- **factorization** (*Tensor*): the factorization of size :math:`(*, m, m)`
|
||||
|
||||
- **pivots** (*IntTensor*): the pivots of size :math:`(*, m)`
|
||||
|
||||
- **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of
|
||||
size :math:`(*)` where non-zero values indicate whether factorization for the matrix or
|
||||
each minibatch has succeeded or failed
|
||||
|
||||
Example::
|
||||
|
||||
>>> A = torch.randn(2, 3, 3)
|
||||
>>> A_LU, pivots = torch.lu(A)
|
||||
>>> A_LU
|
||||
tensor([[[ 1.3506, 2.5558, -0.0816],
|
||||
[ 0.1684, 1.1551, 0.1940],
|
||||
[ 0.1193, 0.6189, -0.5497]],
|
||||
|
||||
[[ 0.4526, 1.2526, -0.3285],
|
||||
[-0.7988, 0.7175, -0.9701],
|
||||
[ 0.2634, -0.9255, -0.3459]]])
|
||||
>>> pivots
|
||||
tensor([[ 3, 3, 3],
|
||||
[ 3, 3, 3]], dtype=torch.int32)
|
||||
>>> A_LU, pivots, info = torch.lu(A, get_infos=True)
|
||||
>>> if info.nonzero().size(0) == 0:
|
||||
... print('LU factorization succeeded for all samples!')
|
||||
LU factorization succeeded for all samples!
|
||||
"""
|
||||
# If get_infos is True, then we don't need to check for errors and vice versa
|
||||
result = torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
|
||||
if out is not None:
|
||||
if not isinstance(out, (tuple, list)):
|
||||
raise TypeError("argument 'out' must be tuple of Tensors, not {}"
|
||||
.format(type(out).__name__))
|
||||
if len(out) - int(get_infos) != 2:
|
||||
raise TypeError("expected tuple of {} elements but got {}"
|
||||
.format(2 + int(get_infos), len(out)))
|
||||
return (out[i].resize_as_(result[i]).copy_(result[i]) for i in range(len(out)))
|
||||
if get_infos:
|
||||
return result # A_LU, pivots, infos
|
||||
else:
|
||||
return result[0], result[1] # A_LU, pivots
|
||||
|
@ -282,10 +282,33 @@ class Tensor(torch._C._TensorBase):
|
||||
def trtrs(self, A, upper=True, transpose=False, unitriangular=False):
|
||||
r"""See :func:`torch.triangular_solve`"""
|
||||
warnings.warn("torch.trtrs is deprecated in favour of torch.triangular_solve and will be "
|
||||
"removed in the next release. Please use torch.triangular_solve.", stacklevel=2)
|
||||
"removed in the next release. Please use torch.triangular_solve instead.",
|
||||
stacklevel=2)
|
||||
return super(Tensor, self).triangular_solve(A, upper=upper,
|
||||
transpose=transpose, unitriangular=unitriangular)
|
||||
|
||||
def btrifact(self, pivot=True):
|
||||
r"""See :func:`torch.lu`"""
|
||||
warnings.warn("torch.btrifact is deprecated in favour of torch.lu and will be removed in "
|
||||
"the next release. Please use torch.lu instead.", stacklevel=2)
|
||||
return torch._lu_with_info(self, pivot=pivot, check_errors=True)
|
||||
|
||||
def btrifact_with_info(self, pivot=True):
|
||||
r"""See :func:`torch.lu`"""
|
||||
warnings.warn("torch.btrifact_with_info is deprecated in favour of torch.lu with the "
|
||||
"and will be removed in the next release. Please use torch.lu with the "
|
||||
"get_infos argument set to True instead.", stacklevel=2)
|
||||
return torch._lu_with_info(self, pivot=pivot, check_errors=False)
|
||||
|
||||
def lu(self, pivot=True, get_infos=False):
|
||||
r"""See :func:`torch.lu`"""
|
||||
# If get_infos is True, then we don't need to check for errors and vice versa
|
||||
LU, pivots, infos = torch._lu_with_info(self, pivot=pivot, check_errors=(not get_infos))
|
||||
if get_infos:
|
||||
return LU, pivots, infos
|
||||
else:
|
||||
return LU, pivots
|
||||
|
||||
def stft(self, n_fft, hop_length=None, win_length=None, window=None,
|
||||
center=True, pad_mode='reflect', normalized=False, onesided=True):
|
||||
r"""See :func:`torch.stft`
|
||||
|
Reference in New Issue
Block a user