73 Commits

Author SHA1 Message Date
ff35e1e45b [pytree] Add custom treespec fqn field (#112428)
Custom classes that are serialized with pytree are serialized by default with `f”{class.__module__}.{class.__name__}”`. This is a dependency from our serialized program directly into the outer Python environment. If a user moves the class to a different directory, the serialized program will be unable to be loaded. So, we will require users to pass in an FQN if they want to serialize their custom treespec type.

Differential Revision: [D50886366](https://our.internmc.facebook.com/intern/diff/D50886366)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112428
Approved by: https://github.com/suo
2023-11-02 00:26:41 +00:00
bbd5b935e4 Use pytree.tree_leaves everywhere (#112324)
This changes all the instances I could find of `tree_flatten(...)[0]` or
`x, _ = tree_flatten` to use `tree_leaves`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112324
Approved by: https://github.com/lezcano
ghstack dependencies: #112327, #112323
2023-10-30 03:39:04 +00:00
449271f3f1 [pytree] Extract reusable generic tests for pytree (#110395)
Part of #109684

- #109684

Changes:

- Add new functions `tree_structure`, `tree_leaves`, `tree_map_` and `tree_map_only_` to Python pytree.
- Extract reusable tests for pytree to `TestGenericPytree`.
- Change `treespec_dumps` and `treespec_loads` in C++ pytree to call Python pytree and use JSON string as serialization type.
- Rename `torch.utils.pytree` -> `torch.utils._cxx_pytree`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110395
Approved by: https://github.com/zou3519
2023-10-04 23:40:50 +00:00
518308a740 Trace through pytree API with dynamo. (#108533)
Fix: #107315

This PR enables dynamo to trace through the `pytree` API by inlining its functions. In
order to do so, a few details of `pytree` had to be changed.

In summary, this PR:

- Introduces `TreeSpecVariable` for representing `TreeSpec` instances
- Specializes `<type>.__bases__` call, returning a `TupleVariable`
- Enables the call to `id` builtin function for every variable that implements
  `as_python_constant` method
- Specializes `ConstantVariable.call_method` for its (un)flatten functions
- Implements `UserDefinedObjectVariable.as_python_constant`
- Modifies `pytree` by:
    - Make `SUPPORTED_NODES` a map of ids (instead of types) to `NodeDef`
    - Removed `functools.wraps` function, since it can't be inlined

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108533
Approved by: https://github.com/ezyang, https://github.com/voznesenskym
ghstack dependencies: #109201
2023-09-20 00:04:56 +00:00
0bf30c140a [pytree] Use OpTree for PyTree manipulation (#93139)
Split from #92679. Use C++-based PyTree implementation.

## Highlights

1. High performance (20x speedup than the pure-Python implementation, 10%-20% overall speedup for `torch.fx`)
2. Multi-input tree-map support
3. Custom tree node registry with namespace isolation

Refs:

- #65761
- #91323
- #92679

From https://github.com/pytorch/pytorch/issues/65761#issuecomment-1334746366:

> ### 0. Out-of-box compatible with JAX's pytree, provides the same interfaces and functions (and more).
>
> ### 1. High-performance: `optree` has comparable fast tree operations (~0.9x for `dict`s and ~2.5x for `OrderedDict`s) than JAX's pytree and it is 20x faster than `torch.utils._pytree`.
>
> `optree` implements some common Python container types in C++ (e.g., `OrderedDict`) and achieves 2.5x performance than JAX's pytree. Check out section [Built-in PyTree Node Types](https://github.com/metaopt/optree#built-in-pytree-node-types) and [Benchmark](https://github.com/metaopt/optree#benchmark) for more details.
>
> | Module    | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
> | :-------- | ----: | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: |
> | TinyMLP   |    53 |       26.40 |        68.19 |       586.87 |        34.14 |            2.58 |           22.23 |            1.29 |
> | AlexNet   |   188 |       84.28 |       259.51 |      2182.07 |       125.12 |            3.08 |           25.89 |            1.48 |
> | ResNet18  |   698 |      288.57 |       807.27 |      7881.69 |       429.39 |            2.80 |           27.31 |            1.49 |
> | ResNet34  |  1242 |      580.75 |      1564.97 |     15082.84 |       819.02 |            2.69 |           25.97 |            1.41 |
> | ResNet50  |  1702 |      791.18 |      2081.17 |     20982.82 |      1104.62 |            2.63 |           26.52 |            1.40 |
> | ResNet101 |  3317 |     1603.93 |      3939.37 |     40382.14 |      2208.63 |            2.46 |           25.18 |            1.38 |
> | ResNet152 |  4932 |     2446.56 |      6267.98 |     56892.36 |      3139.17 |            2.56 |           23.25 |            1.28 |
> | ViT-H/14  |  3420 |     1681.48 |      4488.33 |     41703.16 |      2504.86 |            2.67 |           24.80 |            1.49 |
> | Swin-B    |  2881 |     1565.41 |      4091.10 |     34241.99 |      1936.75 |            2.61 |           21.87 |            1.24 |
> |           |       |             |              |              |  **Average** |        **2.68** |       **24.78** |        **1.38** |
>
> <div align="center">
>   <img src="https://user-images.githubusercontent.com/16078332/200494435-fd5bb385-59f7-4811-b520-98bf5763ccf3.png" width="90%" />
> </div>
>
> ### 2. Namespace Isolation for the PyTree Type Registry
>
> In addition to the JAX's pytree registry for custom node type registration, `optree` adds `namespace` isolation to the registry. Users can register the same type multiple times for different flatten/unflatten behavior. It also provides module-level isolation for safety reasons. For example, you can add a unique prefix to your namespace to isolate your registry with other modules (e.g., `torch.xxx`, `torch.functorch.xxx`):
>
> ```python
> # Register a Python type into a namespace
> import torch
>
> optree.register_pytree_node(
>     torch.Tensor,
>     # (tensor) -> (children, metadata)
>     flatten_func=lambda tensor: (
>         (tensor.cpu().numpy(),),
>         dict(dtype=tensor.dtype, device=tensor.device, requires_grad=tensor.requires_grad),
>     ),
>     # (metadata, children) -> tensor
>     unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata),
>     namespace='torch.torch2numpy',
> )
> ```
>
> ```python
> >>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
> >>> tree
> {'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}
>
> # Flatten without specifying the namespace
> >>> tree_flatten(tree)  # `torch.Tensor`s are leaf nodes
> ([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))
>
> # Flatten with the namespace
> >>> leaves, treespec = optree.tree_flatten(tree, namespace='torch.torch2numpy')
> >>> leaves, treespec
> (
>     [array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
>     PyTreeSpec(
>         {
>             'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cpu'), 'requires_grad': False}], [*]),
>             'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cuda', index=0), 'requires_grad': False}], [*])
>         },
>         namespace='torch.torch2numpy'
>     )
> )
>
> # `entries` are not defined and use `range(len(children))`
> >>> optree.tree_paths(tree, namespace='torch.torch2numpy')
> [('bias', 0), ('weight', 0)]
>
> # Unflatten back to a copy of the original object
> >>> optree.tree_unflatten(treespec, leaves)
> {'bias': tensor([0., 0.]), 'weight': tensor([[1., 1.]], device='cuda:0')}
> ```
>
> Check out section [Registering a Container-like Custom Type as Non-leaf Nodes](https://github.com/metaopt/optree#notes-about-the-pytree-type-registry) for more details.
>
> ### 3. Support both `None` as Non-leaf Node and `None` as Leaf
>
> In JAX's implementation, `None` is always an internal non-leaf node with an arity 0, which is like an empty tuple. This limits the usage of the JAX's pytree utilities for PyTorch. For example, the `nn.Module` uses `_parameters` and `_buffers` (`OrderedDict[str, Optional[Tensor]]`) to hold the tensors, while the value can be a tensor or `None`.
>
> `optree` supports both `None` as Non-leaf Node (JAX's default) and `None` as Leaf (PyTorch's default). Check out section [None is Non-leaf Node vs. None is Leaf](https://github.com/metaopt/optree#none-is-non-leaf-node-vs-none-is-leaf) for more details.
>
> ### 4. Some other improvements and bug fixes
>
> 1. Adds in-place version of treemap (`tree_map_`), which reduces redundant unflatten operation for better performance.
> 2. Adds support for tree flatten and tree map with paths. (useful for `functorch` module extraction).
> 3. Improves the JAX's pytree sorting support for `dict`s.
> 4. Better string representation `repr(PyTreeSpec)`.
> 5. Fixes some bugs for JAX's pytree of hashing, pickle serialization, segmentation fault for infinite recursion, and tree-compose/tree-transpose.

From https://github.com/pytorch/pytorch/pull/92679#issuecomment-1398778481:

> ```python
> # pytree_make_fx_bench.py
> import torch
> from torch.fx.experimental.proxy_tensor import make_fx
> import time
>
> def f(x):
>     for _ in range(10000):
>         x = x+x
>     return x
>
> import time
> begin = time.time()
> out = make_fx(f, tracing_mode="real")(torch.randn(20))
> begin = time.time()
> print(f'tracing_mode="real" {time.time() - begin:.2f}')
> out = make_fx(f, tracing_mode="fake")(torch.randn(20))
> print(f'tracing_mode="fake" {time.time() - begin:.2f}')
>
> out = make_fx(f, tracing_mode="symbolic")(torch.randn(20))
> print(f'tracing_mode="symbolic" {time.time() - begin:.2f}')
> ```
>
> This seems to run around 10-20% faster with the optree implementation:
>
> ```
> # Optree
> python pytree_make_fx_bench.py
> tracing_mode="real" 0.00
> tracing_mode="fake" 6.32
> tracing_mode="symbolic" 27.13
> ```
>
> ```
> # torch.utils._pytree
> python pytree_make_fx_bench.py
> tracing_mode="real" 0.00
> tracing_mode="fake" 7.66
> tracing_mode="symbolic" 31.07
> ```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93139
Approved by: https://github.com/malfet
2023-09-18 21:24:56 +00:00
a432f37e49 Serialize pytree to json string (#106116)
Fixes https://github.com/pytorch/pytorch/pull/102577#issuecomment-1650905536

Serializing to json is more stable, and renamed the API:

```
# Takes in a treespec and returns the serialized treespec as a string. Also optionally takes in a protocol version number.
def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
# Takes in a serialized treespec and outputs a TreeSpec
def treespec_loads(data: str) -> TreeSpec:
```

If users want to register their own serialization format for a given pytree, they can go through the `_register_treespec_serializer` API which optionally takes in a `getstate` and `setstate` function.
```
_register_treespec_serializer(type_, *, getstate, setstate)
# Takes in the context, and outputs a json-dumpable context
def getstate(context: Context) -> DumpableContext:
# Takes in a json-dumpable context, and reconstructs the original context
def setstate(dumpable_context: DumpableContext) -> Context:
```

We will serialize to the following dataclass, and then json.dump this it to string.
```
class TreeSpec
    type: Optional[str]  # a string name of the type. null for the case of a LeafSpec
    context: Optional[Any]  # optional, a json dumpable format of the context
    children_specs: List[TreeSpec],
}
```

If no getstate/setstate function is registered, we will by default serialize the context using `json.dumps/loads`. We will also serialize the type through `f"{typ.__module__}.{typ.__name__}"`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106116
Approved by: https://github.com/zou3519
2023-08-27 14:34:49 +00:00
3a7d77f704 Serialize empty pytree cases (#105159)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105159
Approved by: https://github.com/zhxchen17
2023-07-13 23:02:59 +00:00
ec24f1e4cc Simulate treespec flattening/unflattening (#101896)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101896
Approved by: https://github.com/jansel, https://github.com/anijain2305
2023-06-23 10:53:15 +00:00
c0596ffe85 improve repr for pytrees (#103945)
The current thing indents based on the length of the previous line, which is totally unreadable if, e.g. the treespec is a dict with a lot of keys, since all the keys will go on a ginormous line and everything after will be super indented.

Fix the indentation at 2, which is much more compact.

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103945
Approved by: https://github.com/zou3519
2023-06-21 20:53:03 +00:00
bd0a4e2d83 Serialize pytree to string v2 (#102708)
v2 of https://github.com/pytorch/pytorch/pull/102577
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102708
Approved by: https://github.com/avikchaudhuri
2023-06-01 19:51:28 +00:00
09a967d6c9 Make nested TreeSpec printing nicer (#46538) (#86546)
1. Made TreeSpec into a dataclass.
2. In `__repr__`, recursively transformed TreeSpec into dictionaries and then pretty-printed it.

Fixes #46538. Hi, @ezyang. this PR is for the TreeSpec `__repr__` refactor we discussed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86546
Approved by: https://github.com/ezyang
2022-10-18 16:50:39 +00:00
b8b54eccd2 Add *_only and all/any pytree utilities (#83316)
With a sample usage in proxy tensor to show how they can shorten
your code dramatically.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83316
Approved by: https://github.com/zou3519, https://github.com/albanD, https://github.com/bdhirsh
2022-08-12 17:31:55 +00:00
6700a78504 Move vmap's OrderedDict pytree support to torch.utils._pytree (#83073)
There's no reason why it should just apply when you import vmap

Test Plan:
- added a new test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83073
Approved by: https://github.com/Chillee
2022-08-11 03:00:55 +00:00
18c74d10bb Register torch.return_types.* as pytree nodes
All of the torch.return_types.* are these special things "structseq"
that subclass tuple but have a different constructor from tuple :(.

This PR iterates through all of torch.return_types.* and adds a pytree
registration for them.

Test Plan:
- add tests for max and min which return torch.return_types.max, and
torch.return_types.min, respectively. There's not an easy way to
"get all torch ops that return a return_types object".

Fixes https://github.com/pytorch/pytorch/issues/75218

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75915
Approved by: https://github.com/ezyang, https://github.com/kshitij12345
2022-04-19 13:46:20 +00:00
f9ccf7ab80 [skip ci] Set pytree tests to module: pytree owner (#74686)
Summary:
Based on zou3519's suggestions!

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74686

Reviewed By: dagitses

Differential Revision: D35255764

Pulled By: janeyx99

fbshipit-source-id: 9fe521d6b3b9d6620ad3f06758b1f0d20f9408ad
(cherry picked from commit 6954e1705ccdc9686bc777aaf2e98a922b662946)
2022-03-31 04:25:43 +00:00
a1e284d9c8 Remove high priority as an owner for tests (#74555)
Summary:
Following triage review discussion, it would be best for these tests to not be triaged high priority by automation, but by the triagers in the oncall.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74555

Reviewed By: albanD

Differential Revision: D35099202

Pulled By: janeyx99

fbshipit-source-id: 657a0317141de3a598476a6f601ec26cc26231b1
(cherry picked from commit 057519cb2494d0f9a0b169f359ac87ba9e89f088)
2022-03-24 14:29:52 +00:00
c19cda5782 [skip ci] Add test owners for a special hi-pri class of tests (#67553)
Summary:
Action following https://github.com/pytorch/pytorch/issues/66232

This change does require some context: there were several suggestions regarding what to do about this group of tests: tests that are core and crucial to all of PyTorch and are too broad to be owned by one team.
1. Let's add a "module: core" and put people behind it! This idea sounds appealing unless you are one of the people backing the label. From talking to albanD among others, this idea of putting all these core tests on the shoulder of a few people or one team isn't super fair and I have not yet found anyone willing to take on this job.
2. Taking advantage of the fact that we already have a triaging oncall that takes turns triaging issues, we can leave these tests essentially unlabeled and allow the oncall to triage these tests. Since these tests are crucial to PyTorch, we'll add the "high priority" label to mark them different from other unowned tests (see https://github.com/pytorch/pytorch/issues/67552).
3. I _could_ still create an unbacked label "module: core" and attribute these tests there, but I don't like the idea of creating a facade that the tests are "triaged" to a label when no one is actually taking a look.

Now we could potentially break these tests down into smaller files so that each piece _could_ be owned by a team, but 1. I don't know if this is currently feasible and 2. This approach does not prevent that from happening in the future.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67553

Reviewed By: albanD

Differential Revision: D32025004

Pulled By: janeyx99

fbshipit-source-id: 1fb1aa4c27e305695ab6e80ae3d02f90519939c0
2021-10-29 12:17:21 -07:00
1022443168 Revert D30279364: [codemod][lint][fbcode/c*] Enable BLACK by default
Test Plan: revert-hammer

Differential Revision:
D30279364 (b004307252)

Original commit changeset: c1ed77dfe43a

fbshipit-source-id: eab50857675c51e0088391af06ec0ecb14e2347e
2021-08-12 11:45:01 -07:00
b004307252 [codemod][lint][fbcode/c*] Enable BLACK by default
Test Plan: manual inspection & sandcastle

Reviewed By: zertosh

Differential Revision: D30279364

fbshipit-source-id: c1ed77dfe43a3bde358f92737cd5535ae5d13c9a
2021-08-12 10:58:35 -07:00
52d1ffb789 Teach pytrees about namedtuple (#62292)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62292

This PR adds pytree support for namedtuples. The challenge about namedtuple
is that each namedtuple class is actually different. This PR does the
following:
- it adds a namedtuple flatten/unflatten. The flatten function returns
a context that is the actual type of the namedtuple subclass. The
unflatten function uses that type to reconstruct the namedtuple
- Special cases all pytree logic to consider all namedtuples the same.
This is done by creating a `_get_node_type(pytree)` helper function that
returns `namedtuple` if `pytree` is any namedtuple subclass. The effect
of this is that all namedtuple subclasses will go through the namedtuple
flatten/unflatten functions
- Adds a `_namedtuple_flatten_spec` function for FX pytrees. This function
flattens the namedtuple based on the spec and is equivalent to the
`_tuple_flatten_spec`.

Test Plan
- new tests in test/test_pytree.py and test/test_fx.py

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D29947302

Pulled By: zou3519

fbshipit-source-id: 19c00665b13546642c315df0f243ad99b8e7ff7c
2021-07-28 06:27:44 -07:00
8d363d37da [FX] Adds PyTree support to FX through concrete_args (#55888)
Summary:
```
class Foo(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y, x):
        for k in x:
            for v in x[k]:
                v += y
        return x

example_dict = {'x': {'a': [fx.HOLE], 'z': [fx.HOLE, fx.HOLE]}}
new_f = fx.symbolic_trace(Foo(), concrete_args=example_dict)
print(new_f.code)
new_f(torch.randn(5), {'x': {'a': [torch.randn(5)], 'z': [torch.randn(5), torch.randn(5)]}})

fx.symbolic_trace(new_f, concrete_args=example_dict)
```

prints out
```
def forward(self, y, x):
    y, tree_2, tree_3, tree_4 = pytree.tree_flatten([y, x])[0]
    add = tree_2 + y
    add_1 = tree_3 + y
    add_2 = tree_4 + y;  y = None
    return {'a': [tree_2], 'z': [tree_3, tree_4]}
```

Currently, I store `in_spec` as an extra attribute on `fx.Graph`, and then include it when we do the codegen. I'm not sure if this is the right approach - it introduces a divergence between what's in `fx.Graph` and what's in the python code.

Perhaps the best API is something explicit like `fx.Graph.flatten_args`, but that does make calling things a bit ... more verbose.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/55888

Reviewed By: jamesr66a

Differential Revision: D27884694

Pulled By: Chillee

fbshipit-source-id: f9e8a70c63a8df63c9f9bd0a6459255daa5a8df8
2021-05-07 04:48:35 -07:00
6025f8148a Implement _broadcast_to_and_flatten(pytree, spec) (#46288)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46288

This "broadcasts" `pytree` to have the same structure as `spec`
and then flattens it.
I find it hard to describe what that does in words, so here's an example:

- Broadcasting 1 to have the same structure as [0, [0, 0]] would
return [1, [1, 1]]. Further flattening it gives us [1, 1, 1].
- Broadcasting [1, 2] to have the same structure as [0, [0, 0]] would
return [1, [2, 2]]. Further flattening it gives us [1, 2, 2].

What is this used for?
----------------------
The next PR up in the stack uses this helper function to allow vmap to
accept nested data structures. `vmap(fn, in_dims)(*inputs)` allows the
user to specify in_dims with a tree structure that is a sub-graph of
that of `inputs` (where both contain the root of the tree).

For example, one can do `vmap(fn, in_dims=0)(x, y, z)`. `in_dims` is 0
and inputs is (x, y, z). We would like to broadcast in_dims up to the
structure of inputs to get (0, 0, 0).

Another example, is `vmap(fn, in_dims=(0, 1))(x, [y, z])`. `in_dims` is
(0, 1) and inputs is (x, [y, z]). We would like to broadcast in_dims up
to the structure of inputs to get (0, [1, 1]); this value of in_dims is
used to say "let's vmap over dim 0 for x and dim 1 for y and z".

Test Plan
---------
New tests.

Test Plan: Imported from OSS

Reviewed By: heitorschueroff

Differential Revision: D24392891

Pulled By: zou3519

fbshipit-source-id: 6f494d8b6359582f1b4ab6b8dd6a956d8bfe8ed4
2020-10-20 07:52:14 -07:00
0285618a11 Add utilities to support handling of nested python data structures (#46287)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46287

This adds a lightweight `pytree` implementation that is similar to and
inspired by JAX pytrees, tensorflow.nest, deepmind/tree,
TorchBeast's TensorNest, etc.

A *pytree* is Python nested data structure. It is a tree in the sense
that nodes are Python collections (e.g., list, tuple, dict) and the leaves
are Python values. Furthermore, a pytree should not contain reference
cycles.

This PR:
- adds support for flattening and unflattening nested Python list/dict/tuples

Context: nested Tensor inputs for vmap
--------------------------------------
Right now, vmap is restricted to taking in flat lists of tensors. This
is because vmap needs to be able to convert every tensor in the input
that is being vmapped over into a BatchedTensor.

With a pytree library, we can simply flatten the input data structure
(returning the leaves), map all of the Tensors in the flat input to
BatchedTensors, and unflatten the flat list of BatchedTensors into a new
input. Or equivalently, with a `tree_map` function, we can map a nested
python data structure containing Tensors into one containing
BatchedTensors.

Future work
-----------
In some future PRs, we'll add nested input support for vmap. The
prerequisites for that are:
- a `broadcast_to(small, big)` that broadcasts `small` up to `big`.
  This is for handling the in_dims to vmap: the in_dims structure must
  be compatible with the structure of the inputs.

Test Plan
---------
- New tests in test/test_pytree.py

Test Plan: Imported from OSS

Reviewed By: heitorschueroff

Differential Revision: D24392890

Pulled By: zou3519

fbshipit-source-id: 7daf7430c5a38354e7d203a72882bd7a9b24cfb1
2020-10-20 07:45:45 -07:00