mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[functorch] fixed dot batching rule with no batch dims
This commit is contained in:
@ -26,8 +26,11 @@ std::tuple<Tensor, optional<int64_t>> dot_batch_rule(const Tensor& A, optional<i
|
||||
auto B_ = moveBatchDimToFront(B, B_bdim);
|
||||
if (A_bdim && B_bdim) {
|
||||
return {at::matmul(A_.unsqueeze(-2), B_.unsqueeze(-1)).squeeze(-1).squeeze(-1), 0};
|
||||
} else if (!A_bdim && !B_bdim) {
|
||||
return {at::dot(A_, B_), nullopt};
|
||||
} else {
|
||||
return {at::matmul(A_, B_.t()), 0};
|
||||
}
|
||||
return {at::matmul(A_, B_.t()), 0};
|
||||
}
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user