mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[functorch] fix vmap tests and added dot shape check
This commit is contained in:
@ -22,6 +22,7 @@ slogdet_batch_rule(const Tensor& self, optional<int64_t> self_bdim) {
|
||||
}
|
||||
|
||||
std::tuple<Tensor, optional<int64_t>> dot_batch_rule(const Tensor& A, optional<int64_t> A_bdim, const Tensor& B, optional<int64_t> B_bdim) {
|
||||
TORCH_CHECK(A.dim() - A_bdim.has_value() == 1 && B.dim() - B_bdim.has_value() == 1, "Got wrong shapes for dot");
|
||||
auto A_ = moveBatchDimToFront(A, A_bdim);
|
||||
auto B_ = moveBatchDimToFront(B, B_bdim);
|
||||
if (A_bdim && B_bdim) {
|
||||
|
@ -2635,7 +2635,7 @@ class TestVmapBatchedGradient(Namespace.TestVmapBase):
|
||||
result = vmap(vjp)(gy)
|
||||
self.assertEqual(result, torch.zeros(B0, *x.shape, device=device))
|
||||
|
||||
class TestVmapOperators(TestCase):
|
||||
class TestVmapOperatorsOpInfo(TestCase):
|
||||
@onlyCPU
|
||||
@ops(op_db, allowed_dtypes=(torch.float,))
|
||||
def test_normalize_operator_exhaustive(self, device, dtype, op):
|
||||
@ -2686,7 +2686,7 @@ class TestVmapOperators(TestCase):
|
||||
|
||||
|
||||
only_for = ("cpu", "cuda")
|
||||
instantiate_device_type_tests(TestVmapOperators, globals(), only_for=only_for)
|
||||
instantiate_device_type_tests(TestVmapOperatorsOpInfo, globals(), only_for=only_for)
|
||||
|
||||
instantiate_device_type_tests(
|
||||
TestVmapBatchedGradient,
|
||||
|
Reference in New Issue
Block a user