[functorch] fixed dot batching rule with no batch dims

This commit is contained in:
Horace He
2021-04-29 00:58:35 -07:00
committed by Jon Janzen
parent 0d94ae66a7
commit c521cd1f45

View File

@ -26,8 +26,11 @@ std::tuple<Tensor, optional<int64_t>> dot_batch_rule(const Tensor& A, optional<i
auto B_ = moveBatchDimToFront(B, B_bdim);
if (A_bdim && B_bdim) {
return {at::matmul(A_.unsqueeze(-2), B_.unsqueeze(-1)).squeeze(-1).squeeze(-1), 0};
} else if (!A_bdim && !B_bdim) {
return {at::dot(A_, B_), nullopt};
} else {
return {at::matmul(A_, B_.t()), 0};
}
return {at::matmul(A_, B_.t()), 0};
}