1694 Commits

Author SHA1 Message Date
66ee9e96a3 [functorch] Add opinfo based testing for grad transform 2022-07-21 13:40:53 -07:00
74c7c73672 [functorch] Kwarg support for grad 2022-07-21 13:40:53 -07:00
2f3f44c302 [functorch] cleaned up logic 2022-07-21 13:40:53 -07:00
a92f492b9c [functorch] removed extraneous pdb 2022-07-21 13:40:53 -07:00
0e3a2b2d5c [functorch] made tensors const references - fixes pytorch/functorch#38 2022-07-21 13:40:53 -07:00
ea85c20c35 [functorch] Make the tensor wrap/unwrap logic aware of Tensor?[]
Tensor?[] is an argument to index_put_ from advanced indexing.
Previously we weren't peeking inside Tensor?[] for tensors.

Fixes pytorch/functorch#31
2022-07-21 13:40:53 -07:00
fc5c632ab5 [functorch] fix some compile warnings 2022-07-21 13:40:53 -07:00
a24930ba93 [functorch] Fix being unable to call .backward() after vmap
Fixes pytorch/functorch#37
2022-07-21 13:40:53 -07:00
bf1df4f3af [functorch] ported sumdim batch rule over and added argmax 2022-07-21 13:40:53 -07:00
74d0250734 [functorch] fix vmap tests 2022-07-21 13:40:53 -07:00
fbc76eb6da [functorch] Add mode dispatch stack (pytorch/functorch#34)
* Add mode dispatch stack

* remove some things

* fix some type errors and clean up vmap

* Added a warning
2022-07-21 13:40:53 -07:00
45d4228334 [functorch] Added wraps to grad_and_value 2022-07-21 13:40:53 -07:00
470ecce6e5 [functorch] nll_loss_backward batch rule for some cases 2022-07-21 13:40:53 -07:00
221bdfba33 [functorch] Added a MKLDNN decomposition and new_ones overload 2022-07-21 13:40:53 -07:00
6b1cc7f499 [functorch] added reshape lowerings to nnc 2022-07-21 13:40:53 -07:00
a7f406ce58 [functorch] Quick attempt at hiding functional module init
Introduces a `functional_init` and `functional_init_with_buffers` that
lets one initialize an ensemble of modules more easily than before. This
was done in the sprit of make_functional: the API still looks awkward,
especially when buffers are involved.
2022-07-21 13:40:53 -07:00
7f344c5a0b [functorch] Add batch rule for nll_loss_forward for most common cases 2022-07-21 13:40:53 -07:00
1cd0dd00ce [functorch] Linear batch rule (which is just a decomposition) 2022-07-21 13:40:53 -07:00
9496ea3e8a [functorch] fix build failure 2022-07-21 13:40:53 -07:00
2910423017 [functorch] Update some version numbers" 2022-07-21 13:40:53 -07:00
9e6201db9a [functorch] lennard-jones example and test 2022-07-21 13:40:53 -07:00
d65bb48b46 [functorch] Added a way to call the slow fallback from plumbing 2022-07-21 13:40:53 -07:00
03173fad44 [functorch] Inplace +-*/ batch rules 2022-07-21 13:40:53 -07:00
8e82d5afd1 [functorch] fix issues with passthrough variables 2022-07-21 13:40:53 -07:00
7644e62d11 [functorch] Added statement if there's an empty NNC compute expression 2022-07-21 13:40:53 -07:00
f2204f6045 [functorch] Added some tests + fix some stupid stuff 2022-07-21 13:40:52 -07:00
e1281fca60 [functorch] Fix some failing tests 2022-07-21 13:40:52 -07:00
29b90b4f4f [functorch] Added conv2d batching rule (pytorch/functorch#10)
* Added conv2d batching rule

* resolve some comments

* added some stuff to conv

* remove debugging stuff

* remove clamp + add squeeze to vmap failures

* Added more conv batching rules
2022-07-21 13:40:52 -07:00
354e79afc5 [functorch] log softmax backward batch rule 2022-07-21 13:40:52 -07:00
8e0e341076 [functorch] Fix some of the more important lint errors 2022-07-21 13:40:52 -07:00
69e7dccc25 [functorch] switch from TORCH_INTERNAL_ASSERT to runtime_error 2022-07-21 13:40:52 -07:00
f7662a2101 [functorch] test/common.py -> test/common_utils.py
"from common import ..." can fail pretty easily because "common" is a
common name.
2022-07-21 13:40:52 -07:00
31a3d45c88 [functorch] Added aten::diag and aten::rsub.Scalar batching rules 2022-07-21 13:40:52 -07:00
2ddf7b0f70 [functorch] Switch PythonTensorImpl to using custom contiguity policy
For build reasons.
2022-07-21 13:40:52 -07:00
45a4dda3ad [functorch] Switch to overriding is_contiguous_custom
There's something weird with building this internally otherwise...
2022-07-21 13:40:52 -07:00
38eb311988 [functorch] Fix clamp_min / clamp_max 2022-07-21 13:40:52 -07:00
bb701ef563 [functorch] Roll back clamp_min / clamp_max change
It doesn't compile under clang. Going to investigate this later
2022-07-21 13:40:52 -07:00
256099c1a6 [functorch] Switched over to pytorch core's tree_map 2022-07-21 13:40:52 -07:00
66081519da [functorch] Removed WrapModule 2022-07-21 13:40:52 -07:00
f0f15bc84a [functorch] Batching rules for: threshold_backward, clamp_min, clamp_max 2022-07-21 13:40:52 -07:00
d9d5a52a14 [functorch] Update README's torch nightly version 2022-07-21 13:40:52 -07:00
5e90f1e61b [functorch] Parameterized testing 2022-07-21 13:40:52 -07:00
647ce0ab8d [functorch] removed unnecessary code left in 2022-07-21 13:40:52 -07:00
69ae2b6dd0 [functorch] fix oversight with reusing the same storage 2022-07-21 13:40:52 -07:00
26d2ab7e35 [functorch] updated README 2022-07-21 13:40:52 -07:00
ac017463e0 [functorch] Added a simple function example 2022-07-21 13:40:52 -07:00
25453d611c [functorch] Added nnc compilation stuff to functorch 2022-07-21 13:40:52 -07:00
4b0d62a2af [functorch] Migrate mm and mv batch rules from old style to new style 2022-07-21 13:40:52 -07:00
9bdd9cee5d [functorch] reshape_dim_into and reshape_dim_outof helpers 2022-07-21 13:40:52 -07:00
0c5755fcaa [functorch] Delete some redundant code 2022-07-21 13:40:52 -07:00