Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Allow None to pass through for vmap (#65565)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65565

Does jax allow this?

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D31236258

Pulled By: soulitzer

fbshipit-source-id: 80460b355fc32ecbba8151e1f3179f076a927f9d
committed by Facebook GitHub Bot
parent 89ed9bdaee
commit b6d5f1ee70
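For context, a minimal sketch (not part of the commit) of the behavior that motivates the change: `torch.autograd.grad` with `allow_unused=True` returns None for inputs the output does not depend on, and vmap previously rejected any non-Tensor output.

import torch

x = torch.randn(3, requires_grad=True)
unused = torch.randn(3, requires_grad=True)  # never participates in `y`
y = (x * 2).sum()

# The autograd engine returns None for the unused input instead of a zero tensor.
grad_x, grad_unused = torch.autograd.grad(y, (x, unused), allow_unused=True)
print(grad_unused)  # None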
torch/_vmap_internals.py

@@ -87,8 +87,8 @@ def _create_batched_inputs(
 # Undoes the batching (and any batch dimensions) associated with the `vmap_level`.
 def _unwrap_batched(
         batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
-        out_dims: out_dims_t,
-        vmap_level: int, batch_size: int, func: Callable) -> Tuple:
+        out_dims: out_dims_t, vmap_level: int, batch_size: int, func: Callable,
+        allow_none_pass_through: bool = False) -> Tuple:
     num_outputs = _num_outputs(batched_outputs)
     out_dims_as_tuple = _as_tuple(
         out_dims, num_outputs,
@@ -101,8 +101,12 @@ def _unwrap_batched(
     if isinstance(batched_outputs, Tensor):
         out_dim = out_dims_as_tuple[0]
         return torch._remove_batch_dim(batched_outputs, vmap_level, batch_size, out_dim)  # type: ignore[return-value]
-    return tuple(torch._remove_batch_dim(out, vmap_level, batch_size, out_dim)
-                 for out, out_dim in zip(batched_outputs, out_dims_as_tuple))
+    if allow_none_pass_through:
+        return tuple((torch._remove_batch_dim(out, vmap_level, batch_size, out_dim) if out is not None else None)
+                     for out, out_dim in zip(batched_outputs, out_dims_as_tuple))
+    else:
+        return tuple(torch._remove_batch_dim(out, vmap_level, batch_size, out_dim)
+                     for out, out_dim in zip(batched_outputs, out_dims_as_tuple))
 
 # Checks that `fn` returned one or more Tensors and nothing else.
 # NB: A Python function that returns multiple values returns a single tuple,
@@ -253,7 +257,11 @@ def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
     return _vmap(func, in_dims, out_dims)
 
 # A version of vmap but without the initial "experimental prototype" warning
-def _vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
+def _vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0, allow_none_pass_through: bool = False) -> Callable:
+    # The `allow_none_pass_through` argument is a temporary workaround and may be removed.
+    # Currently it enables us to wrap calls to `autograd.grad` (the autograd engine),
+    # which may return None if any of the inputs are unused. See the issue discussing this:
+    # https://github.com/facebookresearch/functorch/issues/159.
     @functools.wraps(func)
     def wrapped(*args):
         _check_out_dims_is_int_or_int_tuple(out_dims, func)
@@ -261,8 +269,10 @@ def _vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
     try:
         batched_inputs, batch_size = _create_batched_inputs(in_dims, args, vmap_level, func)
         batched_outputs = func(*batched_inputs)
-        _validate_outputs(batched_outputs, func)
-        return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
+        if not allow_none_pass_through:
+            _validate_outputs(batched_outputs, func)
+        return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func,
+                               allow_none_pass_through=allow_none_pass_through)
     finally:
         torch._C._vmapmode_decrement_nesting()
     return wrapped
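A minimal usage sketch of the new flag (illustrative, not from the PR; the function and variable names below are assumptions): vmapping a vector-Jacobian product over a batch of grad_outputs, where one differentiated input is unused and its gradient comes back as None.

import torch
from torch._vmap_internals import _vmap

x = torch.randn(3, requires_grad=True)
unused = torch.randn(3, requires_grad=True)
y = x * 2  # `unused` does not contribute to `y`

def vjp(v):
    # autograd returns None for `unused` because `y` does not depend on it
    return torch.autograd.grad(y, (x, unused), v, allow_unused=True)

basis = torch.eye(3)  # one grad_output row per sample
# Before this change _validate_outputs raised on the None output; with
# allow_none_pass_through=True the None is forwarded untouched while real
# gradients are unbatched as usual.
jac_x, jac_unused = _vmap(vjp, allow_none_pass_through=True)(basis)
assert jac_unused is None
assert torch.allclose(jac_x, 2 * torch.eye(3))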