mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[functorch] removed no bdim cases from batchruleslinearalgebra
This commit is contained in:
@ -10,14 +10,6 @@ namespace at { namespace functorch {
|
||||
|
||||
std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>>
|
||||
slogdet_batch_rule(const Tensor& self, optional<int64_t> self_bdim) {
|
||||
if (!self_bdim.has_value()) {
|
||||
auto result = at::slogdet(self);
|
||||
return std::make_tuple(
|
||||
std::move(std::get<0>(result)), nullopt,
|
||||
std::move(std::get<1>(result)), nullopt
|
||||
);
|
||||
}
|
||||
|
||||
// slogdet supports arbitrary dims at the front
|
||||
auto self_ = moveBatchDimToFront(self, self_bdim);
|
||||
auto result = at::slogdet(self_);
|
||||
@ -51,10 +43,7 @@ std::tuple<Tensor, optional<int64_t>> dot_batch_rule(const Tensor& A, optional<i
|
||||
static std::tuple<Tensor, optional<int64_t>> tv_batch_rule(
|
||||
const Tensor& self, optional<int64_t> self_bdim,
|
||||
const Tensor& other, optional<int64_t> other_bdim) {
|
||||
if (!self_bdim && !other_bdim) {
|
||||
return std::make_tuple( at::matmul(self, other), nullopt );
|
||||
}
|
||||
else if (self_bdim && other_bdim) {
|
||||
if (self_bdim && other_bdim) {
|
||||
// See Note [Batching rules for matmul-like operators]
|
||||
// B...OI, BI -> ...BOI, BI1 -> ...BO1 -> ...BO
|
||||
auto self_ = at::movedim(self, *self_bdim, -3);
|
||||
@ -81,9 +70,6 @@ static std::tuple<Tensor, optional<int64_t>> tv_batch_rule(
|
||||
static std::tuple<Tensor, optional<int64_t>> mv_batch_rule(
|
||||
const Tensor& self, optional<int64_t> self_bdim,
|
||||
const Tensor& other, optional<int64_t> other_bdim) {
|
||||
if (!self_bdim && !other_bdim) {
|
||||
return std::make_tuple( at::mv(self, other), nullopt );
|
||||
}
|
||||
auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
|
||||
auto other_logical_rank = rankWithoutBatchDim(other, other_bdim);
|
||||
TORCH_CHECK(self_logical_rank == 2 && other_logical_rank == 1,
|
||||
@ -97,9 +83,6 @@ static std::tuple<Tensor, optional<int64_t>> mv_batch_rule(
|
||||
static std::tuple<Tensor, optional<int64_t>> mm_batch_rule(
|
||||
const Tensor& self, optional<int64_t> self_bdim,
|
||||
const Tensor& other, optional<int64_t> other_bdim) {
|
||||
if (!self_bdim && !other_bdim) {
|
||||
return std::make_tuple( at::matmul(self, other), nullopt );
|
||||
}
|
||||
auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
|
||||
auto other_logical_rank = rankWithoutBatchDim(other, other_bdim);
|
||||
TORCH_CHECK(self_logical_rank == 2 && other_logical_rank == 2,
|
||||
@ -115,9 +98,6 @@ static std::tuple<Tensor, optional<int64_t>> mm_batch_rule(
|
||||
static std::tuple<Tensor, optional<int64_t>> bmm_batch_rule(
|
||||
const Tensor& self, optional<int64_t> self_bdim,
|
||||
const Tensor& other, optional<int64_t> other_bdim) {
|
||||
if (!self_bdim && !other_bdim) {
|
||||
return std::make_tuple( at::bmm(self, other), nullopt );
|
||||
}
|
||||
auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
|
||||
auto other_logical_rank = rankWithoutBatchDim(other, other_bdim);
|
||||
TORCH_CHECK(self_logical_rank == 3 && other_logical_rank == 3,
|
||||
@ -146,11 +126,7 @@ std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>> linalg_eigh_batch_
|
||||
const Tensor& self, optional<int64_t> self_bdim, c10::string_view UPLO) {
|
||||
auto self_ = moveBatchDimToFront(self, self_bdim);
|
||||
auto result = at::linalg_eigh(self_, UPLO);
|
||||
optional<int64_t> result_bdim;
|
||||
if (self_bdim) {
|
||||
result_bdim = 0;
|
||||
}
|
||||
return std::make_tuple(std::get<0>(result), result_bdim, std::get<1>(result), result_bdim);
|
||||
return std::make_tuple(std::get<0>(result), 0, std::get<1>(result), 0);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
|
||||
|
Reference in New Issue
Block a user