mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[fix] legacybatching: getPhysicalDims (#93261)
Fixes #92985 Minimum Repro: ```python import torch from torch._vmap_internals import vmap input = torch.randn(2, 2) def fn(x): return x.sum(()) o = vmap(fn)(input) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/93261 Approved by: https://github.com/albanD, https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
7a621c443b
commit
845e4b8a47
@ -1871,6 +1871,7 @@ class TestVmapOperators(Namespace.TestVmapBase):
|
||||
|
||||
# Single vmap, various in_dims / out_dims
|
||||
test(lambda x: x.sum(()), [torch.randn([B0])])
|
||||
test(lambda x: x.sum(()), [torch.randn([B0, 2])])
|
||||
test(lambda x: x.sum(0), [torch.randn([B0])])
|
||||
test(lambda x: x.sum(-1), [torch.randn([B0])])
|
||||
test(lambda x: x.sum(0), [torch.randn([B0, 3])])
|
||||
|
Reference in New Issue
Block a user