[functorch] fix vmap tests and added dot shape check

This commit is contained in:
Horace He
2021-05-03 00:39:42 -07:00
committed by Jon Janzen
parent 73f57a7192
commit 86e49cf0d7
2 changed files with 3 additions and 2 deletions

View File

@ -22,6 +22,7 @@ slogdet_batch_rule(const Tensor& self, optional<int64_t> self_bdim) {
}
std::tuple<Tensor, optional<int64_t>> dot_batch_rule(const Tensor& A, optional<int64_t> A_bdim, const Tensor& B, optional<int64_t> B_bdim) {
TORCH_CHECK(A.dim() - A_bdim.has_value() == 1 && B.dim() - B_bdim.has_value() == 1, "Got wrong shapes for dot");
auto A_ = moveBatchDimToFront(A, A_bdim);
auto B_ = moveBatchDimToFront(B, B_bdim);
if (A_bdim && B_bdim) {

View File

@ -2635,7 +2635,7 @@ class TestVmapBatchedGradient(Namespace.TestVmapBase):
result = vmap(vjp)(gy)
self.assertEqual(result, torch.zeros(B0, *x.shape, device=device))
class TestVmapOperators(TestCase):
class TestVmapOperatorsOpInfo(TestCase):
@onlyCPU
@ops(op_db, allowed_dtypes=(torch.float,))
def test_normalize_operator_exhaustive(self, device, dtype, op):
@ -2686,7 +2686,7 @@ class TestVmapOperators(TestCase):
only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestVmapOperators, globals(), only_for=only_for)
instantiate_device_type_tests(TestVmapOperatorsOpInfo, globals(), only_for=only_for)
instantiate_device_type_tests(
TestVmapBatchedGradient,