[functorch] Switched over to pytorch core's tree_map

This commit is contained in:
Horace He
2021-05-10 10:41:58 -07:00
committed by Jon Janzen
parent 66081519da
commit 256099c1a6
4 changed files with 4 additions and 9 deletions

View File

@ -3,8 +3,8 @@ from functools import partial, wraps
import collections
import torch.nn as nn
import torch.nn.functional as F
from torch.utils._pytree import tree_flatten, tree_unflatten
from .pytree_hacks import tree_map, tree_map_, treespec_pprint
from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
from .pytree_hacks import tree_map_, treespec_pprint
import gc
from .vmap import vmap

View File

@ -28,11 +28,6 @@ def tree_flatten_hack(pytree):
return result, _pytree.TreeSpec(typ, context, children_specs)
# TODO: replace this with tree_map from core
def tree_map(fn, pytree):
flat_args, spec = tree_flatten(pytree)
return tree_unflatten([fn(arg) for arg in flat_args], spec)
def tree_map_(fn_, pytree):
flat_args, _ = tree_flatten(pytree)
[fn_(arg) for arg in flat_args]

View File

@ -117,7 +117,7 @@ def _unwrap_batched(
# Some weird edge case requires us to spell out the following
# see test_out_dims_edge_case
if isinstance(out_dims, int):
flat_out_dims = [out_dims]
flat_out_dims = [out_dims]
elif isinstance(out_dims, tuple) and len(out_dims) == 1:
flat_out_dims = out_dims
out_dims = out_dims[0]

View File

@ -2737,7 +2737,7 @@ class TestVmapOperatorsOpInfo(TestCase):
return
# entries in here need don't work and need to be fixed.
vmap_fail = {'repeat'}
vmap_fail = {'repeat', 'ravel', 'clamp'}
if op.name in vmap_fail:
return
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)