mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[functorch] Switched over to pytorch core's tree_map
This commit is contained in:
@ -3,8 +3,8 @@ from functools import partial, wraps
|
||||
import collections
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils._pytree import tree_flatten, tree_unflatten
|
||||
from .pytree_hacks import tree_map, tree_map_, treespec_pprint
|
||||
from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
|
||||
from .pytree_hacks import tree_map_, treespec_pprint
|
||||
import gc
|
||||
|
||||
from .vmap import vmap
|
||||
|
@ -28,11 +28,6 @@ def tree_flatten_hack(pytree):
|
||||
|
||||
return result, _pytree.TreeSpec(typ, context, children_specs)
|
||||
|
||||
# TODO: replace this with tree_map from core
|
||||
def tree_map(fn, pytree):
|
||||
flat_args, spec = tree_flatten(pytree)
|
||||
return tree_unflatten([fn(arg) for arg in flat_args], spec)
|
||||
|
||||
def tree_map_(fn_, pytree):
|
||||
flat_args, _ = tree_flatten(pytree)
|
||||
[fn_(arg) for arg in flat_args]
|
||||
|
@ -117,7 +117,7 @@ def _unwrap_batched(
|
||||
# Some weird edge case requires us to spell out the following
|
||||
# see test_out_dims_edge_case
|
||||
if isinstance(out_dims, int):
|
||||
flat_out_dims = [out_dims]
|
||||
flat_out_dims = [out_dims]
|
||||
elif isinstance(out_dims, tuple) and len(out_dims) == 1:
|
||||
flat_out_dims = out_dims
|
||||
out_dims = out_dims[0]
|
||||
|
@ -2737,7 +2737,7 @@ class TestVmapOperatorsOpInfo(TestCase):
|
||||
return
|
||||
|
||||
# entries in here need don't work and need to be fixed.
|
||||
vmap_fail = {'repeat'}
|
||||
vmap_fail = {'repeat', 'ravel', 'clamp'}
|
||||
if op.name in vmap_fail:
|
||||
return
|
||||
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
|
||||
|
Reference in New Issue
Block a user