[fix] legacybatching: getPhysicalDims (#93261)

Fixes #92985

Minimum Repro:
```python
import torch
from torch._vmap_internals import vmap

input = torch.randn(2, 2)

def fn(x):
    return x.sum(())

o = vmap(fn)(input)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93261
Approved by: https://github.com/albanD, https://github.com/Skylion007
This commit is contained in:
Kshiteej K
2023-01-30 21:06:32 +00:00
committed by PyTorch MergeBot
parent 7a621c443b
commit 845e4b8a47
2 changed files with 2 additions and 1 deletions

View File

@ -1871,6 +1871,7 @@ class TestVmapOperators(Namespace.TestVmapBase):
# Single vmap, various in_dims / out_dims
test(lambda x: x.sum(()), [torch.randn([B0])])
test(lambda x: x.sum(()), [torch.randn([B0, 2])])
test(lambda x: x.sum(0), [torch.randn([B0])])
test(lambda x: x.sum(-1), [torch.randn([B0])])
test(lambda x: x.sum(0), [torch.randn([B0, 3])])