[functorch] removed no bdim cases from batchruleslinearalgebra

Author:    Horace He
Date:      2021-06-28 14:55:41 -07:00
Committer: Jon Janzen
Parent:    8cd80e0b16
Commit:    cb286b9b49

BatchRulesLinearAlgebra.cpp

@@ -10,14 +10,6 @@ namespace at { namespace functorch {
 std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>>
 slogdet_batch_rule(const Tensor& self, optional<int64_t> self_bdim) {
-  if (!self_bdim.has_value()) {
-    auto result = at::slogdet(self);
-    return std::make_tuple(
-        std::move(std::get<0>(result)), nullopt,
-        std::move(std::get<1>(result)), nullopt
-    );
-  }
   // slogdet supports arbitrary dims at the front
   auto self_ = moveBatchDimToFront(self, self_bdim);
   auto result = at::slogdet(self_);
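The surviving body leans on at::slogdet accepting arbitrary leading batch dimensions, so moving the vmap batch dim to the front is all the rule needs. Below is a minimal standalone sketch of that shape flow, assuming a libtorch/ATen build; the shapes are illustrative, and the movedim call merely stands in for functorch's moveBatchDimToFront helper.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // An illustrative [3, 4, 5, 5] input whose dim 1 (size 4) plays the role of
  // the vmap batch dim over 5x5 matrices.
  auto self = at::randn({3, 4, 5, 5});
  // moveBatchDimToFront(self, 1) is morally a movedim(1, 0):
  auto self_ = self.movedim(1, 0);           // [4, 3, 5, 5]
  auto result = at::slogdet(self_);          // sign and logabsdet, both [4, 3]
  std::cout << std::get<0>(result).sizes() << " "
            << std::get<1>(result).sizes() << "\n";
  return 0;
}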
@@ -51,10 +43,7 @@ std::tuple<Tensor, optional<int64_t>> dot_batch_rule(const Tensor& A, optional<i
 static std::tuple<Tensor, optional<int64_t>> tv_batch_rule(
     const Tensor& self, optional<int64_t> self_bdim,
     const Tensor& other, optional<int64_t> other_bdim) {
-  if (!self_bdim && !other_bdim) {
-    return std::make_tuple( at::matmul(self, other), nullopt );
-  }
-  else if (self_bdim && other_bdim) {
+  if (self_bdim && other_bdim) {
     // See Note [Batching rules for matmul-like operators]
     // B...OI, BI -> ...BOI, BI1 -> ...BO1 -> ...BO
     auto self_ = at::movedim(self, *self_bdim, -3);
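The B...OI, BI -> ...BOI, BI1 -> ...BO1 -> ...BO comment compresses the whole dimension dance for the both-batched case. A minimal sketch of the same sequence of ATen calls, assuming libtorch; the batch size and shapes are made up for illustration only.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // self: B...OI with B = 4 (the vmap dim, at dim 0), ... = (2,), O = 3, I = 5
  // other: BI with the same B = 4 and I = 5
  auto self  = at::randn({4, 2, 3, 5});
  auto other = at::randn({4, 5});
  // B...OI -> ...BOI: move the batch dim next to the matrix dims
  auto self_ = at::movedim(self, 0, -3);                 // [2, 4, 3, 5]
  // BI -> BI1: turn the vector into a column matrix
  auto other_ = other.unsqueeze(-1);                     // [4, 5, 1]
  // ...BOI @ BI1 -> ...BO1 -> ...BO
  auto result = at::matmul(self_, other_).squeeze(-1);   // [2, 4, 3]
  std::cout << result.sizes() << "\n";
  return 0;
}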
@@ -81,9 +70,6 @@ static std::tuple<Tensor, optional<int64_t>> tv_batch_rule(
 static std::tuple<Tensor, optional<int64_t>> mv_batch_rule(
     const Tensor& self, optional<int64_t> self_bdim,
     const Tensor& other, optional<int64_t> other_bdim) {
-  if (!self_bdim && !other_bdim) {
-    return std::make_tuple( at::mv(self, other), nullopt );
-  }
   auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
   auto other_logical_rank = rankWithoutBatchDim(other, other_bdim);
   TORCH_CHECK(self_logical_rank == 2 && other_logical_rank == 1,
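The checks in these rules compare "logical" ranks, i.e. ranks with the vmap batch dim (if any) not counted. The real helper is functorch's rankWithoutBatchDim; the logical_rank function below is a hypothetical stand-in, written here only to illustrate the idea under a plain libtorch build.

#include <ATen/ATen.h>
#include <iostream>

// Hypothetical stand-in for functorch's rankWithoutBatchDim helper:
// the tensor's rank, not counting the vmap batch dim if one is present.
static int64_t logical_rank(const at::Tensor& t, c10::optional<int64_t> bdim) {
  return t.dim() - (bdim.has_value() ? 1 : 0);
}

int main() {
  auto mat = at::randn({4, 2, 3});   // batched at dim 0: each example is a 2x3 matrix
  auto vec = at::randn({3});         // unbatched vector
  std::cout << logical_rank(mat, 0) << " "               // 2
            << logical_rank(vec, c10::nullopt) << "\n";  // 1
  return 0;
}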
@@ -97,9 +83,6 @@ static std::tuple<Tensor, optional<int64_t>> mv_batch_rule(
 static std::tuple<Tensor, optional<int64_t>> mm_batch_rule(
     const Tensor& self, optional<int64_t> self_bdim,
     const Tensor& other, optional<int64_t> other_bdim) {
-  if (!self_bdim && !other_bdim) {
-    return std::make_tuple( at::matmul(self, other), nullopt );
-  }
   auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
   auto other_logical_rank = rankWithoutBatchDim(other, other_bdim);
   TORCH_CHECK(self_logical_rank == 2 && other_logical_rank == 2,
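When only one operand carries a batch dim, a matmul-style rule can move that dim to the front and let at::matmul's batch broadcasting do the rest. A hedged sketch of that shape flow, assuming libtorch and illustrative shapes; it is not claimed to be the exact path mm_batch_rule takes.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  auto self  = at::randn({4, 2, 3});   // vmap bdim 0 (size 4), logical shape 2x3
  auto other = at::randn({3, 5});      // no bdim, logical shape 3x5
  // With the batch dim already at the front, at::matmul broadcasts over it:
  auto result = at::matmul(self, other);   // [4, 2, 5], output bdim 0
  std::cout << result.sizes() << "\n";
  return 0;
}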
@@ -115,9 +98,6 @@ static std::tuple<Tensor, optional<int64_t>> mm_batch_rule(
 static std::tuple<Tensor, optional<int64_t>> bmm_batch_rule(
     const Tensor& self, optional<int64_t> self_bdim,
     const Tensor& other, optional<int64_t> other_bdim) {
-  if (!self_bdim && !other_bdim) {
-    return std::make_tuple( at::bmm(self, other), nullopt );
-  }
   auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
   auto other_logical_rank = rankWithoutBatchDim(other, other_bdim);
   TORCH_CHECK(self_logical_rank == 3 && other_logical_rank == 3,
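at::bmm itself insists on two 3-D inputs, so a batch rule that reaches the bmm kernel with only one batched operand has to materialize the other side's batch dim somehow. One illustrative option, sketched below under the assumption of a libtorch build, is to expand the unbatched operand and fold the vmap dim into bmm's own batch dimension; this is not claimed to be what functorch's bmm_batch_rule does.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  auto self  = at::randn({4, 6, 2, 3});  // vmap bdim 0 (size 4), logical 6x2x3
  auto other = at::randn({6, 3, 5});     // no vmap bdim, logical 6x3x5
  // Materialize the missing batch dim on `other`, then fold the vmap dim
  // into bmm's batch dimension and split it back out afterwards.
  auto other_ = other.unsqueeze(0).expand({4, 6, 3, 5});
  auto out = at::bmm(self.reshape({4 * 6, 2, 3}),
                     other_.reshape({4 * 6, 3, 5}));     // [24, 2, 5]
  std::cout << out.reshape({4, 6, 2, 5}).sizes() << "\n";
  return 0;
}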
@@ -146,11 +126,7 @@ std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>> linalg_eigh_batch_
     const Tensor& self, optional<int64_t> self_bdim, c10::string_view UPLO) {
   auto self_ = moveBatchDimToFront(self, self_bdim);
   auto result = at::linalg_eigh(self_, UPLO);
-  optional<int64_t> result_bdim;
-  if (self_bdim) {
-    result_bdim = 0;
-  }
-  return std::make_tuple(std::get<0>(result), result_bdim, std::get<1>(result), result_bdim);
+  return std::make_tuple(std::get<0>(result), 0, std::get<1>(result), 0);
 }
 
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
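Returning bdim 0 unconditionally is consistent with the batch dim having been moved to the front before the call: at::linalg_eigh keeps leading batch dims on both outputs. A small sketch, assuming libtorch, with an illustrative batch of symmetric 5x5 matrices.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // The vmap batch dim (size 4) has already been moved to the front;
  // each example is a symmetric 5x5 matrix.
  auto self_ = at::randn({4, 5, 5});
  auto sym = self_ + self_.transpose(-2, -1);
  auto result = at::linalg_eigh(sym, "L");
  std::cout << std::get<0>(result).sizes() << " "   // eigenvalues:  [4, 5]
            << std::get<1>(result).sizes() << "\n"; // eigenvectors: [4, 5, 5]
  return 0;
}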