Allow None to pass through for vmap (#65565)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65565

Does jax allow this?

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D31236258

Pulled By: soulitzer

fbshipit-source-id: 80460b355fc32ecbba8151e1f3179f076a927f9d
This commit is contained in:
soulitzer
2021-10-03 19:52:30 -07:00
committed by Facebook GitHub Bot
parent 89ed9bdaee
commit b6d5f1ee70

View File

@ -87,8 +87,8 @@ def _create_batched_inputs(
# Undos the batching (and any batch dimensions) associated with the `vmap_level`.
def _unwrap_batched(
batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
out_dims: out_dims_t,
vmap_level: int, batch_size: int, func: Callable) -> Tuple:
out_dims: out_dims_t, vmap_level: int, batch_size: int, func: Callable,
allow_none_pass_through: bool = False) -> Tuple:
num_outputs = _num_outputs(batched_outputs)
out_dims_as_tuple = _as_tuple(
out_dims, num_outputs,
@ -101,8 +101,12 @@ def _unwrap_batched(
if isinstance(batched_outputs, Tensor):
out_dim = out_dims_as_tuple[0]
return torch._remove_batch_dim(batched_outputs, vmap_level, batch_size, out_dim) # type: ignore[return-value]
return tuple(torch._remove_batch_dim(out, vmap_level, batch_size, out_dim)
for out, out_dim in zip(batched_outputs, out_dims_as_tuple))
if allow_none_pass_through:
return tuple((torch._remove_batch_dim(out, vmap_level, batch_size, out_dim) if out is not None else None)
for out, out_dim in zip(batched_outputs, out_dims_as_tuple))
else:
return tuple(torch._remove_batch_dim(out, vmap_level, batch_size, out_dim)
for out, out_dim in zip(batched_outputs, out_dims_as_tuple))
# Checks that `fn` returned one or more Tensors and nothing else.
# NB: A python function that return multiple arguments returns a single tuple,
@ -253,7 +257,11 @@ def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Ca
return _vmap(func, in_dims, out_dims)
# A version of vmap but without the initial "experimental prototype" warning
def _vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
def _vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0, allow_none_pass_through: bool = False) -> Callable:
# The `allow_none_pass_through` argument is a temporary workaround may be removed.
# Currently it enables us to wrap the call in `autograd.grad` to the autograd engine,
# which may return None if any of the inputs are unused. See the issue discussing this:
# https://github.com/facebookresearch/functorch/issues/159.
@functools.wraps(func)
def wrapped(*args):
_check_out_dims_is_int_or_int_tuple(out_dims, func)
@ -261,8 +269,10 @@ def _vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> C
try:
batched_inputs, batch_size = _create_batched_inputs(in_dims, args, vmap_level, func)
batched_outputs = func(*batched_inputs)
_validate_outputs(batched_outputs, func)
return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
if not allow_none_pass_through:
_validate_outputs(batched_outputs, func)
return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func,
allow_none_pass_through=allow_none_pass_through)
finally:
torch._C._vmapmode_decrement_nesting()
return wrapped