1694 Commits

Author SHA1 Message Date
56378e986c [functorch] Added batching rules for convolution, conv1d, conv2d, conv3d, etc. 2022-07-21 13:40:56 -07:00
1666d90161 [functorch] _unsafe_view batch rule 2022-07-21 13:40:56 -07:00
bea0df36c2 [functorch] Clarify the vmap fallback warning message
I got feedback from two users who were unsure whether the warning
meant anything. The warning now clearly states that there is a
performance degradation and asks users to file an issue about it.
2022-07-21 13:40:56 -07:00
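The clarified fallback warning can be sketched in pure Python; the exact message text below is hypothetical, paraphrasing the commit's intent rather than quoting functorch's actual warning:

```python
import warnings

def vmap_fallback_warning(op_name):
    # Hypothetical message in the spirit of the commit: state the
    # performance degradation explicitly and ask the user to file an issue.
    warnings.warn(
        f"Batching rule not implemented for {op_name}; falling back to a "
        f"for-loop. This may cause a performance degradation. Please file "
        f"an issue so we can prioritize adding a batching rule.",
        UserWarning,
    )
```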
cfcd328ff7 [functorch] Added note about reaching out 2022-07-21 13:40:56 -07:00
0b60292122 [functorch] Quick optimization 2022-07-21 13:40:56 -07:00
5acd1b995d [functorch] linalg_eigh batch rule 2022-07-21 13:40:56 -07:00
75a64bb029 [functorch] Fix some headers 2022-07-21 13:40:55 -07:00
79e42fc15b [functorch] Implement more batch rules for resnet18 per-sample-grads 2022-07-21 13:40:55 -07:00
a054adeb74 [functorch] Fix build/test 2022-07-21 13:40:55 -07:00
dc5e0c0f58 [functorch] cudnn_convolution decomposition 2022-07-21 13:40:55 -07:00
eb788ae986 [functorch] maxpool_2d_with_indices_backward batch rule for specific case
This gives us coverage on all the batch rules for the cifar10 dp
example. We're slightly slower than opacus, though; the numbers on my
machine are:
- 4 it/s for functorch vmap+grad
- 4.2 it/s for opacus

The gap should be investigated, and I am also not sure whether the
benchmarks are comparing the right things.
2022-07-21 13:40:55 -07:00
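The per-sample-gradient pattern being benchmarked can be illustrated without torch or functorch. This toy sketch (names hypothetical) computes analytic per-sample gradients of a squared-error loss for a 1-D linear model y = w * x; differentially private training, as in the cifar10 example, needs each sample's gradient separately rather than just the mean gradient:

```python
def per_sample_grads(w, xs, ys):
    # d/dw (w*x - y)^2 = 2 * (w*x - y) * x, computed for each sample
    # separately -- the quantity vmap+grad vectorizes in functorch.
    return [2.0 * (w * x - y) * x for x, y in zip(xs, ys)]

def mean_grad(w, xs, ys):
    # Ordinary training only needs this reduction of the per-sample grads.
    grads = per_sample_grads(w, xs, ys)
    return sum(grads) / len(grads)
```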
1262034bb7 [functorch] Implement per sample grad rule for cudnn_convolution_backward
This is one special case of the batching rule...
2022-07-21 13:40:55 -07:00
cfa9d98499 [functorch] quick fix 2022-07-21 13:40:55 -07:00
4a045e9659 [functorch] Introduce gen_plumbing.py
Problem: writing plumbing is repetitive and time-consuming.
Solution:
- run `gen_plumbing.py add.Tensor ~/pytorch/build/aten/src/ATen/RegistrationDeclarations.h`
- copy and paste the output into a source file.

In the long run we don't want to be checking this codegen into the
codebase. However, I haven't figured out what the design for the
long-term codegen should actually look like: how does one specify that
we want to *insert* some user-defined code into the middle of a
function? There are a few ideas:

Idea 1: ADD_TENSOR_PLUMBING_BEGIN and ADD_TENSOR_PLUMBING_END macros
```
ADD_TENSOR_PLUMBING_BEGIN
// your C++ logic here
ADD_TENSOR_PLUMBING_END
```

Idea 2: big .yaml file
```
- func: add.Tensor
// your C++ logic here

- func: add.Scalar
// your C++ logic here
```
2022-07-21 13:40:55 -07:00
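A much-simplified, hypothetical version of what such a plumbing generator might do. The real gen_plumbing.py parses RegistrationDeclarations.h from a PyTorch build; this toy handles a single hand-written schema string and emits a C++ stub with the "your logic here" insertion point the ideas above debate:

```python
import re

def gen_plumbing(schema):
    """Emit C++ plumbing boilerplate for a schema line such as
    'add.Tensor(Tensor self, Tensor other) -> Tensor'. Toy sketch only;
    not the real gen_plumbing.py format."""
    m = re.match(r"(\w+)\.(\w+)\((.*)\)\s*->\s*(\w+)", schema)
    name, overload, args, ret = m.groups()
    params = [a.strip().split() for a in args.split(",")]
    sig = ", ".join(f"const {t}& {n}" for t, n in params)
    call_args = ", ".join(n for _, n in params)
    return (
        f"{ret} {name}_{overload}_plumbing({sig}) {{\n"
        f"  // your C++ logic here\n"
        f"  return at::{name}({call_args});\n"
        f"}}\n"
    )
```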
03fbb542fc [functorch] Very experimental PyTorch forward-mode AD support
PyTorch forward mode AD doesn't support a lot of operations yet. I've
verified that `jvp` can compose with `vmap`, but unfortunately `jvp`
doesn't compose with itself.
2022-07-21 13:40:55 -07:00
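Forward-mode AD itself can be sketched with dual numbers. This torch-free toy shows the jvp contract (primal output plus tangent output) that the commit wires up through PyTorch's real forward-mode machinery; the class and function names here are illustrative, not the functorch API:

```python
class Dual:
    """Minimal dual number: a primal value plus a tangent."""
    def __init__(self, primal, tangent=0.0):
        self.primal, self.tangent = primal, tangent

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)
    __radd__ = __add__

    def __mul__(self, other):
        # Product rule carries the tangent forward.
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)
    __rmul__ = __mul__

def jvp(f, primal, tangent):
    # Evaluate f once on a dual input; the tangent of the output is the
    # Jacobian-vector product.
    out = f(Dual(primal, tangent))
    return out.primal, out.tangent
```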
e8c5f67cd8 [functorch] Setup circleci (pytorch/functorch#53)
* Fix some more std::make_tuple things

* Setup circleci

Scripts taken from pytorch/nestedtensor.
2022-07-21 13:40:55 -07:00
4aea806404 [functorch] Added enough things into denylists so that tests pass 2022-07-21 13:40:55 -07:00
3524505d7e [functorch] Create "lagging op database", use it in our OpInfo tests
We have a problem where our tests fail every time we rebase onto the
most recent version of PyTorch. It would be nice to distinguish between
"PyTorch broke a previously passing test" and "PyTorch added a new test
that would have already failed on functorch".

The solution that this PR introduces is for functorch to maintain a
"lagging" OpInfo database. The lagging database needs to be updated
every once in a while with new OpInfos from pytorch/pytorch. This makes
it so that functorch does not randomly get new OpInfo tests.
2022-07-21 13:40:55 -07:00
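The mechanism is roughly this (names and entries below are hypothetical): keep a pinned snapshot of op entries and only run OpInfo tests for ops present in the snapshot, so newly added upstream ops don't surface as surprise failures until the snapshot is deliberately refreshed:

```python
# Hypothetical pinned snapshot; the real lagging database holds full
# OpInfo objects copied from a known pytorch/pytorch revision.
LAGGING_OP_DB = ["add", "mul", "sum"]

def ops_to_test(upstream_op_db):
    # Intersect the live upstream database with the pinned snapshot:
    # a previously passing op that starts failing means PyTorch broke it;
    # brand-new upstream ops are simply not tested yet.
    pinned = set(LAGGING_OP_DB)
    return [op for op in upstream_op_db if op in pinned]
```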
fee49501db [functorch] Fix a lot of warnings; use c10::irange 2022-07-21 13:40:55 -07:00
e3429647d4 [functorch] Change initializer list tuple return to std::make_tuple 2022-07-21 13:40:55 -07:00
b29e666ade [functorch] [BC-breaking] Update make_functional* (pytorch/functorch#52)
Updates make_functional to use the new, improved variants. The new
variants are superior in every way, so we're replacing the previous
variants with them.

If someone wants the older variants, they can be found at:
- make_functional_with_buffers_deprecated_v1
- make_functional_deprecated_v1
2022-07-21 13:40:55 -07:00
fdcc680c9d [functorch] Update tests and examples to use make_functional*_v2
make_functional*_v2 is superior to the older make_functional. This PR
updates all of our examples to use it.

This PR also adds a "combine_state_for_ensemble(models)" API.

Coming soon: we're probably going to break BC on make_functional and
replace it with make_functional_v2. That's the nice thing about being a
prototype: we don't have to worry about BC too much.
2022-07-21 13:40:55 -07:00
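The shape of the make_functional / combine_state_for_ensemble APIs can be sketched torch-free. This toy uses a dict-based "model" and illustrative signatures, not the real functorch ones: make_functional splits state out into (pure function, params), and combine_state_for_ensemble stacks corresponding params across models so one vectorized call can evaluate the whole ensemble:

```python
def make_functional(model):
    # Split a stateful "model" (here: a dict holding [w, b] for y = w*x + b)
    # into a pure function and a parameter tuple.
    params = tuple(model["params"])
    def func(params, x):
        w, b = params
        return w * x + b
    return func, params

def combine_state_for_ensemble(models):
    # Stack corresponding parameters across models, analogous to how
    # combine_state_for_ensemble stacks per-model tensors along a new dim.
    param_lists = [make_functional(m)[1] for m in models]
    return [list(ps) for ps in zip(*param_lists)]
```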
ba1952c176 [functorch] Added some notes about pythonkey tracing to readme 2022-07-21 13:40:55 -07:00
58e7df77d3 [functorch] Added std batching rule 2022-07-21 13:40:55 -07:00
710d06c815 [functorch] templated out sum/mean/var batching rules and added nansum 2022-07-21 13:40:55 -07:00
1b90b429d7 [functorch] make_functional*_v2
Version 2 of make_functional.
2022-07-21 13:40:55 -07:00
4a20c215ce [functorch] vmap-of-vjp and vjp-of-vmap OpInfo testing
Plus some refactoring of the vmap testing to reuse functions between all
of the mentioned tests.

Fixes pytorch/functorch#28.
2022-07-21 13:40:55 -07:00
e58d7dde62 [functorch] Add newlines to eof 2022-07-21 13:40:55 -07:00
324f4d5e51 [functorch] Implement batch norm batch rule for one case (where everything is batched)
It's not clear how to write the other cases.
2022-07-21 13:40:54 -07:00
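Conceptually, the fully-batched case reduces to applying batch norm independently per vmap example, since each example carries its own activations and its own statistics. A torch-free sketch of that reduction (1-D inputs, no affine parameters; function names are illustrative):

```python
import math
from statistics import fmean

def batch_norm_1d(xs, eps=1e-5):
    # Normalize a single example's activations by their own mean/variance.
    mu = fmean(xs)
    var = fmean([(x - mu) ** 2 for x in xs])
    return [(x - mu) / math.sqrt(var + eps) for x in xs]

def vmap_batch_norm(batched_xs):
    # The "everything is batched" case: the batch rule just maps the
    # unbatched computation over the extra vmap dimension.
    return [batch_norm_1d(xs) for xs in batched_xs]
```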
7a1ee75ff3 [functorch] adaptive_avg_pool2d batch rule 2022-07-21 13:40:54 -07:00
804a901abc [functorch] Fix nightly binary links 2022-07-21 13:40:54 -07:00
c91ef2f13e [functorch] refactored argmax batching rule to support argmin too 2022-07-21 13:40:54 -07:00
039f98d0ea [functorch] Fixed argmax batching rules 2022-07-21 13:40:54 -07:00
9644436866 [functorch] Added a bunch of unary activation functions 2022-07-21 13:40:54 -07:00
f536c20b05 [functorch] fix bug with repeat batching rules, fixes pytorch/functorch#9 2022-07-21 13:40:54 -07:00
51408b2b32 [functorch] move some test info around 2022-07-21 13:40:54 -07:00
180214ddb9 [functorch] Fixed oversight in vmap tests, fixed bug in sum batching rule, and added var batching rule 2022-07-21 13:40:54 -07:00
1279ee4282 [functorch] add lowering for triangular_solve + getitem 2022-07-21 13:40:54 -07:00
bfcaed043d [functorch] Added get_ops and made lowering work for tuple outputs 2022-07-21 13:40:54 -07:00
6877541cc1 [functorch] Add some citations 2022-07-21 13:40:54 -07:00
b627d3516e [functorch] Citations 2022-07-21 13:40:54 -07:00
c7759445e5 [functorch] Add vjpvjp tests, fix pytorch/functorch#44
The bug was that TensorWrapper wasn't setting the storage_offset.
Unfortunate.
2022-07-21 13:40:54 -07:00
5ead88c7dc [functorch] Added code of conduct + contributing 2022-07-21 13:40:54 -07:00
06a722867e [functorch] Added a reference to the license in the README 2022-07-21 13:40:54 -07:00
c446037aa4 [functorch] Add LICENSE headers to code files 2022-07-21 13:40:54 -07:00
1a3a2cf1ca [functorch] Refactor vjp testing 2022-07-21 13:40:54 -07:00
dd2d217a3e [functorch] Added error checking to vjp, also added opinfo tests for vjp
Not surprisingly, vjp has the same problems as grad (but no more
problems). Maybe we can just run vjp tests instead of grad tests in the
future.
2022-07-21 13:40:54 -07:00
a49af32a6e [functorch] Fix incorrect vjp semantics 2022-07-21 13:40:54 -07:00
92a4886afa [functorch] Beef up grad testing, eliminate some false errors 2022-07-21 13:40:54 -07:00
7ddbbc392f [functorch] Enable colors in build log 2022-07-21 13:40:54 -07:00