Commit Graph

2 Commits

Author SHA1 Message Date
6025f8148a Implement _broadcast_to_and_flatten(pytree, spec) (#46288)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46288

This "broadcasts" `pytree` to have the same structure as `spec`
and then flattens it.
I find it hard to describe what that does in words, so here's an example:

- Broadcasting 1 to have the same structure as [0, [0, 0]] would
return [1, [1, 1]]. Further flattening it gives us [1, 1, 1].
- Broadcasting [1, 2] to have the same structure as [0, [0, 0]] would
return [1, [2, 2]]. Further flattening it gives us [1, 2, 2].

What is this used for?
----------------------
The next PR up in the stack uses this helper function to allow vmap to
accept nested data structures. `vmap(fn, in_dims)(*inputs)` allows the
user to specify in_dims with a tree structure that is a sub-graph of
that of `inputs` (where both contain the root of the tree).

For example, one can do `vmap(fn, in_dims=0)(x, y, z)`. `in_dims` is 0
and inputs is (x, y, z). We would like to broadcast in_dims up to the
structure of inputs to get (0, 0, 0).

Another example, is `vmap(fn, in_dims=(0, 1))(x, [y, z])`. `in_dims` is
(0, 1) and inputs is (x, [y, z]). We would like to broadcast in_dims up
to the structure of inputs to get (0, [1, 1]); this value of in_dims is
used to say "let's vmap over dim 0 for x and dim 1 for y and z".

Test Plan
---------
New tests.

Test Plan: Imported from OSS

Reviewed By: heitorschueroff

Differential Revision: D24392891

Pulled By: zou3519

fbshipit-source-id: 6f494d8b6359582f1b4ab6b8dd6a956d8bfe8ed4
2020-10-20 07:52:14 -07:00
0285618a11 Add utilities to support handling of nested python data structures (#46287)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46287

This adds a lightweight `pytree` implementation that is similar to and
inspired by JAX pytrees, tensorflow.nest, deepmind/tree,
TorchBeast's TensorNest, etc.

A *pytree* is Python nested data structure. It is a tree in the sense
that nodes are Python collections (e.g., list, tuple, dict) and the leaves
are Python values. Furthermore, a pytree should not contain reference
cycles.

This PR:
- adds support for flattening and unflattening nested Python list/dict/tuples

Context: nested Tensor inputs for vmap
--------------------------------------
Right now, vmap is restricted to taking in flat lists of tensors. This
is because vmap needs to be able to convert every tensor in the input
that is being vmapped over into a BatchedTensor.

With a pytree library, we can simply flatten the input data structure
(returning the leaves), map all of the Tensors in the flat input to
BatchedTensors, and unflatten the flat list of BatchedTensors into a new
input. Or equivalently, with a `tree_map` function, we can map a nested
python data structure containing Tensors into one containing
BatchedTensors.

Future work
-----------
In some future PRs, we'll add nested input support for vmap. The
prerequisites for that are:
- a `broadcast_to(small, big)` that broadcasts `small` up to `big`.
  This is for handling the in_dims to vmap: the in_dims structure must
  be compatible with the structure of the inputs.

Test Plan
---------
- New tests in test/test_pytree.py

Test Plan: Imported from OSS

Reviewed By: heitorschueroff

Differential Revision: D24392890

Pulled By: zou3519

fbshipit-source-id: 7daf7430c5a38354e7d203a72882bd7a9b24cfb1
2020-10-20 07:45:45 -07:00