[functorch] pytree output support for vmap
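In short, vmap can now return an arbitrary pytree of Tensors (nested tuples and lists) instead of only a single Tensor or a flat tuple, and `out_dims` is broadcast against that structure. A minimal usage sketch, adapted from the tests added below; it assumes `vmap` is importable from the top-level `functorch` package:

    import torch
    from functorch import vmap  # assumed import path

    x = torch.randn(2, 3)

    def f(x):
        y = x.sin()
        # Nested containers of Tensors are now valid vmap outputs.
        return y, (y, y), [y, (y, y)]

    # Every leaf of the returned pytree carries the batch dim (size 2) at dim 0.
    y0, (y1, y2), (y3, (y4, y5)) = vmap(f)(x)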
@@ -4,6 +4,7 @@ import collections
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils._pytree import tree_flatten, tree_unflatten
+from .pytree_hacks import tree_map, tree_map_
 import gc

 from .vmap import vmap
@@ -16,16 +17,6 @@ from functorch._C import (
     _grad_decrement_nesting,
 )

-# TODO: replace this with tree_map from core
-def tree_map(fn, pytree):
-    flat_args, spec = tree_flatten(pytree)
-    return tree_unflatten([fn(arg) for arg in flat_args], spec)
-
-def tree_map_(fn_, pytree):
-    flat_args, _ = tree_flatten(pytree)
-    [fn_(arg) for arg in flat_args]
-    return pytree
-
 # TODO: replace all of these with pytrees
 def _create_differentiable(tensor_or_tuple_of_tensors, level=None):
     if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
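The `tree_map`/`tree_map_` helpers deleted here are not dropped; they move into the new `pytree_hacks` module added next and are re-imported above. A rough sketch of what they do, assuming the module path introduced by this commit: `tree_map` rebuilds the pytree with a function applied to every leaf, while `tree_map_` applies a side-effecting function to each leaf and returns the original pytree.

    import torch
    from functorch._src.pytree_hacks import tree_map, tree_map_  # path per the new file below

    nested = [torch.ones(2), (torch.zeros(3), torch.ones(1))]
    doubled = tree_map(lambda t: t * 2, nested)      # same structure, every leaf doubled
    tree_map_(lambda t: print(t.shape), nested)      # prints each leaf's shape, returns `nested`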
functorch/functorch/_src/pytree_hacks.py (new file, 39 lines)
@@ -0,0 +1,39 @@
+import torch.utils._pytree as _pytree
+from torch.utils._pytree import tree_flatten, tree_unflatten
+
+# TODO: The following function should only be used with vmap.
+# torch.return_types should be registered as PyTree nodes.
+# I can't figure out how to do that, so we are turning all of them
+# into normal Tuples for now (this is what vmap used to do anyways).
+# We probably want some special behavior for named tuples?
+def tree_flatten_hack(pytree):
+    if _pytree._is_leaf(pytree) and not isinstance(pytree, tuple):
+        return [pytree], _pytree.LeafSpec()
+
+    if isinstance(pytree, tuple):
+        typ = tuple
+    else:
+        typ = type(pytree)
+
+    flatten_fn = _pytree.SUPPORTED_NODES[typ].flatten_fn
+    child_pytrees, context = flatten_fn(pytree)
+
+    # Recursively flatten the children
+    result : List[Any] = []
+    children_specs : List['TreeSpec'] = []
+    for child in child_pytrees:
+        flat, child_spec = tree_flatten_hack(child)
+        result += flat
+        children_specs.append(child_spec)
+
+    return result, _pytree.TreeSpec(typ, context, children_specs)
+
+# TODO: replace this with tree_map from core
+def tree_map(fn, pytree):
+    flat_args, spec = tree_flatten(pytree)
+    return tree_unflatten([fn(arg) for arg in flat_args], spec)
+
+def tree_map_(fn_, pytree):
+    flat_args, _ = tree_flatten(pytree)
+    [fn_(arg) for arg in flat_args]
+    return pytree
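The point of `tree_flatten_hack` over the stock `tree_flatten` is the tuple special case called out in the comments: tuple subclasses such as the `torch.return_types` structseqs are not registered pytree nodes, so core `tree_flatten` treats them as a single leaf, whereas the hack flattens them like plain tuples so vmap can strip the batch dimension from each Tensor inside. A sketch of the difference, with behavior as described by the comments above (the exact leaf counts are an illustration, not taken from the diff):

    import torch
    from torch.utils._pytree import tree_flatten
    from functorch._src.pytree_hacks import tree_flatten_hack  # module path per this commit

    result = torch.max(torch.randn(3, 4), dim=0)   # torch.return_types.max(values=..., indices=...)

    leaves, spec = tree_flatten(result)
    # -> 1 leaf: the whole return_types object is treated as a single leaf

    leaves, spec = tree_flatten_hack(result)
    # -> 2 leaves (values, indices), with a plain-tuple TreeSpec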
|
@@ -3,6 +3,8 @@ import functools
 from torch import Tensor
 from typing import Any, Callable, Optional, Tuple, Union, List
 from torch.utils._pytree import tree_flatten, tree_unflatten, _broadcast_to_and_flatten
+from .pytree_hacks import tree_flatten_hack, tree_map_
+from functools import partial
 import warnings

 from functorch._C import (
@@ -96,46 +98,54 @@ def _unwrap_batched(
         batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
         out_dims: out_dims_t,
         vmap_level: int, batch_size: int, func: Callable) -> Tuple:
-    num_outputs = _num_outputs(batched_outputs)
-    out_dims_as_tuple = _as_tuple(
-        out_dims, num_outputs,
-        lambda: f'vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must '
-        f'have one dim per output (got {num_outputs} outputs) of {_get_name(func)}.')
+    flat_batched_outputs, output_spec = tree_flatten_hack(batched_outputs)

-    # NOTE [Ignored _remove_batch_dim, _add_batch_dim]
-    # There is something wrong with our type bindings for functions that begin
-    # with '_', see #40397.
-    if isinstance(batched_outputs, Tensor):
-        out_dim = out_dims_as_tuple[0]
-        return _remove_batch_dim(batched_outputs, vmap_level, batch_size, out_dim)  # type: ignore
-    return tuple(_remove_batch_dim(out, vmap_level, batch_size, out_dim)  # type: ignore
-                 for out, out_dim in zip(batched_outputs, out_dims_as_tuple))
-
-# Checks that `fn` returned one or more Tensors and nothing else.
-# NB: A python function that return multiple arguments returns a single tuple,
-# so we are effectively checking that `outputs` is a single Tensor or a tuple of
-# Tensors.
-def _validate_outputs(outputs: Any, func: Callable) -> None:
-    if isinstance(outputs, Tensor):
-        return
-    if not isinstance(outputs, tuple):
-        raise ValueError(f'vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return '
-                         f'Tensors, got type {type(outputs)} as the return.')
-    for idx, output in enumerate(outputs):
-        if isinstance(output, Tensor):
+    for out in flat_batched_outputs:
+        if isinstance(out, torch.Tensor):
             continue
         raise ValueError(f'vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return '
-                         f'Tensors, got type {type(output)} for return {idx}.')
+                         f'Tensors, got type {type(out)} as a return.')

-def _check_out_dims_is_int_or_int_tuple(out_dims: out_dims_t, func: Callable) -> None:
+    def incompatible_error():
+        raise ValueError(
+            f'vmap({_get_name(func)}, ..., out_dims={out_dims})(<inputs>): '
+            f'out_dims is not compatible with the structure of `outputs`. '
+            f'out_dims has structure {tree_flatten(out_dims)[1]} but outputs '
+            f'has structure {output_spec}.')
+
+    if isinstance(batched_outputs, torch.Tensor):
+        # Some weird edge case requires us to spell out the following
+        # see test_out_dims_edge_case
+        if isinstance(out_dims, int):
+            flat_out_dims = [out_dims]
+        elif isinstance(out_dims, tuple) and len(out_dims) == 1:
+            flat_out_dims = out_dims
+            out_dims = out_dims[0]
+        else:
+            incompatible_error()
+    else:
+        flat_out_dims = _broadcast_to_and_flatten(out_dims, output_spec)
+        if flat_out_dims is None:
+            incompatible_error()
+
+    flat_outputs = [
+        _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim)
+        for batched_output, out_dim in zip(flat_batched_outputs, flat_out_dims)
+    ]
+    return tree_unflatten(flat_outputs, output_spec)
+
+def _check_int(x, func, out_dims):
+    if isinstance(x, int):
+        return
+    raise ValueError(
+        f'vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be '
+        f'an int or a python collection of ints representing where in the outputs the '
+        f'vmapped dimension should appear.')
+
+def _check_out_dims_is_int_or_int_pytree(out_dims: out_dims_t, func: Callable) -> None:
     if isinstance(out_dims, int):
         return
-    if not isinstance(out_dims, tuple) or \
-            not all([isinstance(out_dim, int) for out_dim in out_dims]):
-        raise ValueError(
-            f'vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be '
-            f'an int or a tuple of int representing where in the outputs the '
-            f'vmapped dimension should appear.')
+    tree_map_(partial(_check_int, func=func, out_dims=out_dims), out_dims)

 def _get_name(func: Callable):
     if hasattr(func, '__name__'):
@@ -250,13 +260,12 @@ def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Ca
 def _vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
     @functools.wraps(func)
     def wrapped(*args):
-        _check_out_dims_is_int_or_int_tuple(out_dims, func)
+        _check_out_dims_is_int_or_int_pytree(out_dims, func)
         vmap_level = _vmap_increment_nesting()
         torch._C._vmapmode_increment_nesting()
         try:
             batched_inputs, batch_size = _create_batched_inputs(in_dims, args, vmap_level, func)
             batched_outputs = func(*batched_inputs)
-            _validate_outputs(batched_outputs, func)
             return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
         finally:
             torch._C._vmapmode_decrement_nesting()
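The rewritten `_unwrap_batched` above drives the new `out_dims` semantics: a bare int is applied to every output leaf, and a tuple/pytree of ints is expanded against the output structure with `_broadcast_to_and_flatten`, so a single int can stand in for a whole subtree; anything that cannot be broadcast hits `incompatible_error` and raises the "not compatible" ValueError. A sketch of the resulting call-site behavior, mirroring the tests below (assumed top-level `functorch` import):

    import torch
    from functorch import vmap

    x = torch.randn(2, 3)

    def f(x):
        y = x.sin()
        return y, (y, y)

    a, (b, c) = vmap(f, out_dims=1)(x)            # one int broadcasts to every leaf: all outputs are (3, 2)
    a, (b, c) = vmap(f, out_dims=(0, 1))(x)       # 0 for y, 1 broadcast over the (y, y) subtree
    a, (b, c) = vmap(f, out_dims=(0, (0, 1)))(x)  # fully spelled out, one dim per leaf
    # vmap(f, out_dims=(0,))(x) would raise: out_dims structure is not compatible with the outputs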
@@ -27,13 +27,13 @@ class EnableVmapFallbackWarnings:

 class TestVmapAPI(TestCase):
     def test_non_tensor_output_raises(self):
-        with self.assertRaisesRegex(ValueError, "got type <class 'float'> as the return"):
+        with self.assertRaisesRegex(ValueError, "got type <class 'float'> as a return"):
             output = vmap(lambda x: 3.14)(torch.ones(3))

         def multiple_outputs(x):
             return x, 3

-        with self.assertRaisesRegex(ValueError, "got type <class 'int'> for return 1"):
+        with self.assertRaisesRegex(ValueError, "got type <class 'int'> as a return"):
             vmap(multiple_outputs)(torch.ones(3))

     def test_different_map_dim_size_raises(self):
@@ -90,7 +90,7 @@ class TestVmapAPI(TestCase):
         self.assertEqual(outputs[0], x * x)
         self.assertEqual(outputs[1], x * x * x)

-    def test_multiple_outputs_error_cases(self):
+    def test_multiple_outputs(self):
         # This is the same thing as
         # def returns_tuple_of_tensors(x):
         #     return x, x
@@ -107,13 +107,8 @@

         # should not throw
         vmap(returns_tuple_of_tensors)(x)
-
-        # jax supports these, but we don't yet
-        msg = "must only return Tensors, got type <class 'list'>"
-        with self.assertRaisesRegex(ValueError, msg):
-            vmap(returns_list_of_two_tensors)(x)
-        with self.assertRaisesRegex(ValueError, msg):
-            vmap(returns_list_of_one_tensor)(x)
+        vmap(returns_list_of_two_tensors)(x)
+        vmap(returns_list_of_one_tensor)(x)

     def test_nested_with_same_map_dim(self):
         x = torch.randn(2, 3, 5)
@@ -267,8 +262,59 @@ class TestVmapAPI(TestCase):
         result = vmap(foo, out_dims=(1,))(tensor)
         self.assertEqual(result, expected)

-    def test_out_dims_must_be_int_or_tuple_of_int_err_msg(self):
-        msg = '`out_dims` must be an int or a tuple of int'
+    def test_pytree_returns(self):
+        x = torch.randn(2, 3)
+
+        def f(x):
+            y = x.sin()
+            return y, (y, y), [y, (y, y)]
+
+        y0, (y1, y2), (y3, (y4, y5)) = vmap(f)(x)
+        self.assertEqual(y0, x.sin())
+        self.assertEqual(y0, y1)
+        self.assertEqual(y2, y1)
+        self.assertEqual(y2, y3)
+        self.assertEqual(y4, y3)
+        self.assertEqual(y5, y4)
+
+    def test_pytree_returns_outdims(self):
+        x = torch.randn(2, 3)
+
+        def f(x):
+            y = x.sin()
+            return y, (y, y)
+
+        y0, (y1, y2) = vmap(f, out_dims=(0, (0, 1)))(x)
+        self.assertEqual(y0, x.sin())
+        self.assertEqual(y1, x.sin())
+        self.assertEqual(y2, x.sin().t())
+
+    def test_pytree_returns_broadcast_simple(self):
+        x = torch.randn(2, 3)
+
+        def f(x):
+            y = x.sin()
+            return y, (y, y)
+
+        y0, (y1, y2) = vmap(f, out_dims=1)(x)
+        self.assertEqual(y0, x.sin().t())
+        self.assertEqual(y1, y0)
+        self.assertEqual(y2, y0)
+
+    def test_pytree_returns_broadcast_nested(self):
+        x = torch.randn(2, 3)
+
+        def f(x):
+            y = x.sin()
+            return y, (y, y)
+
+        y0, (y1, y2) = vmap(f, out_dims=(0, 1))(x)
+        self.assertEqual(y0, x.sin())
+        self.assertEqual(y1, y0.t())
+        self.assertEqual(y2, y0.t())
+
+    def test_out_dims_must_be_int_or_collection_of_int_err_msg(self):
+        msg = 'must be an int or a python collection of ints'
         tensor = torch.randn(2, 3)
         with self.assertRaisesRegex(ValueError, msg):
             vmap(lambda x: x, out_dims='lol')(tensor)
@@ -280,7 +326,7 @@
             vmap(lambda x: x, out_dims=(None,))(tensor)

     def test_out_dims_and_num_outputs_mismatch_err_msg(self):
-        msg = '`out_dims` must have one dim per output'
+        msg = 'not compatible'
         x = torch.randn(2, 3, 5)

         # Too many out_dims
@@ -2639,9 +2685,9 @@ class TestVmapOperators(TestCase):
         self.assertEqual(loop_out, batched_out)


-instantiate_device_type_tests(TestVmapOperators, globals())
-
+only_for = ("cpu", "cuda")
+instantiate_device_type_tests(TestVmapOperators, globals(), only_for=only_for)

 instantiate_device_type_tests(
     TestVmapBatchedGradient,
     globals(),