66ee9e96a3
[functorch] Add opinfo based testing for grad transform
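For illustration, the testing pattern sketched with hypothetical names (`check_grad` is not the actual test helper): for each op and sample input, functorch's grad should agree with plain autograd.

```python
import torch
from functorch import grad

# Illustrative sketch of the OpInfo-style check, not the real test:
# compare functorch's grad against torch.autograd on a scalar loss.
def check_grad(op, x):
    def f(t):
        return op(t).sum()
    x = x.clone().requires_grad_(True)
    (expected,) = torch.autograd.grad(f(x), x)
    torch.testing.assert_close(grad(f)(x), expected)

check_grad(torch.sin, torch.randn(3, 4))
```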
2022-07-21 13:40:53 -07:00
74c7c73672
[functorch] Kwarg support for grad
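A minimal sketch of what this enables (assuming the functorch-era import path): keyword arguments are forwarded to the wrapped function, while differentiation still happens with respect to the positional argument.

```python
import torch
from functorch import grad

def loss(x, *, scale=1.0):
    return (scale * x ** 2).sum()

x = torch.randn(3)
# `scale` is forwarded to `loss`; the gradient is taken w.r.t. `x`.
dx = grad(loss)(x, scale=2.0)  # equals 4 * x
```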
2022-07-21 13:40:53 -07:00
2f3f44c302
[functorch] cleaned up logic
2022-07-21 13:40:53 -07:00
a92f492b9c
[functorch] removed extraneous pdb
2022-07-21 13:40:53 -07:00
0e3a2b2d5c
[functorch] made tensors const references - fixes pytorch/functorch#38
2022-07-21 13:40:53 -07:00
ea85c20c35
[functorch] Make the tensor wrap/unwrap logic aware of Tensor?[]
Tensor?[] is the type of an argument to index_put_, which advanced
indexing calls into. Previously we weren't peeking inside Tensor?[] for tensors.
Fixes pytorch/functorch#31
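For context, advanced-indexing writes lower to `index_put_`, whose indices argument has type `Tensor?[]` (a list of optional tensors, where `None` stands for a full slice), so the tensors to wrap/unwrap sit inside that list:

```python
import torch

x = torch.zeros(3, 4)
rows = torch.tensor([0, 2])
vals = torch.ones(2, 4)
# `x[rows] = vals` lowers to index_put_ with indices of type Tensor?[];
# the wrap/unwrap logic now looks inside this list for tensors.
x.index_put_((rows,), vals)
```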
2022-07-21 13:40:53 -07:00
fc5c632ab5
[functorch] fix some compile warnings
2022-07-21 13:40:53 -07:00
a24930ba93
[functorch] Fix being unable to call .backward() after vmap
Fixes pytorch/functorch#37
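A minimal sketch of the failure mode this addresses (the exact reproduction in pytorch/functorch#37 may differ):

```python
import torch
from functorch import vmap

x = torch.randn(3, requires_grad=True)
out = vmap(torch.sin)(x)
# Previously this call failed; after the fix, vmap's outputs
# participate in the regular autograd graph.
out.sum().backward()
```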
2022-07-21 13:40:53 -07:00
bf1df4f3af
[functorch] ported sumdim batch rule over and added argmax
2022-07-21 13:40:53 -07:00
74d0250734
[functorch] fix vmap tests
2022-07-21 13:40:53 -07:00
fbc76eb6da
[functorch] Add mode dispatch stack (pytorch/functorch#34)
* Add mode dispatch stack
* Remove some things
* Fix some type errors and clean up vmap
* Add a warning
2022-07-21 13:40:53 -07:00
45d4228334
[functorch] Added wraps to grad_and_value
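The effect of `functools.wraps` here is that the transformed function keeps the original's metadata; for instance:

```python
import torch
from functorch import grad_and_value

def loss(x):
    return (x ** 2).sum()

g = grad_and_value(loss)
assert g.__name__ == loss.__name__  # preserved by `wraps`
grads, value = g(torch.randn(3))    # gradient and loss in one call
```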
2022-07-21 13:40:53 -07:00
470ecce6e5
[functorch] nll_loss_backward batch rule for some cases
2022-07-21 13:40:53 -07:00
221bdfba33
[functorch] Added an MKLDNN decomposition and a new_ones overload
2022-07-21 13:40:53 -07:00
6b1cc7f499
[functorch] added reshape lowerings to nnc
2022-07-21 13:40:53 -07:00
a7f406ce58
[functorch] Quick attempt at hiding functional module init
Introduces a `functional_init` and `functional_init_with_buffers` that
lets one initialize an ensemble of modules more easily than before. This
was done in the spirit of make_functional; the API still looks awkward,
especially when buffers are involved.
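A rough usage sketch; the signature and return values below are assumptions based on this description, not the confirmed API:

```python
import torch.nn as nn
from functorch import functional_init  # introduced by this commit

# Hypothetical usage: build stacked parameters for an ensemble of 5
# Linear(4, 3) modules in one call; the real signature may differ.
weights, fn, names = functional_init(nn.Linear, ensemble_shape=(5,))(4, 3)
# `weights` would hold per-model parameters stacked along dim 0,
# ready to vmap `fn` over the ensemble dimension.
```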
2022-07-21 13:40:53 -07:00
7f344c5a0b
[functorch] Add batch rule for nll_loss_forward for most common cases
2022-07-21 13:40:53 -07:00
1cd0dd00ce
[functorch] Linear batch rule (which is just a decomposition)
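That is, instead of a dedicated rule, `linear` is expanded into ops that already have batch rules; a Python sketch of the decomposition's semantics:

```python
import torch

def linear_decomposed(input, weight, bias=None):
    # torch.nn.functional.linear computes input @ weight.T (+ bias);
    # matmul and add already have batch rules, so vmap falls out.
    out = input @ weight.t()
    if bias is not None:
        out = out + bias
    return out
```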
2022-07-21 13:40:53 -07:00
9496ea3e8a
[functorch] fix build failure
2022-07-21 13:40:53 -07:00
2910423017
[functorch] Update some version numbers
2022-07-21 13:40:53 -07:00
9e6201db9a
[functorch] lennard-jones example and test
2022-07-21 13:40:53 -07:00
d65bb48b46
[functorch] Added a way to call the slow fallback from plumbing
2022-07-21 13:40:53 -07:00
03173fad44
[functorch] Inplace +-*/ batch rules
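This makes patterns like the following work under vmap (a small usage sketch):

```python
import torch
from functorch import vmap

def f(x, y):
    x = x * 1.0   # fresh tensor owned by this call
    x.add_(y)     # in-place arithmetic ops now have batch rules
    x.div_(2.0)
    return x

out = vmap(f)(torch.randn(5, 3), torch.randn(5, 3))
```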
2022-07-21 13:40:53 -07:00
8e82d5afd1
[functorch] fix issues with passthrough variables
2022-07-21 13:40:53 -07:00
7644e62d11
[functorch] Added a statement for the case of an empty NNC compute expression
2022-07-21 13:40:53 -07:00
f2204f6045
[functorch] Added some tests and fixed some issues
2022-07-21 13:40:52 -07:00
e1281fca60
[functorch] Fix some failing tests
2022-07-21 13:40:52 -07:00
29b90b4f4f
[functorch] Added conv2d batching rule (pytorch/functorch#10)
* Added conv2d batching rule
* Resolved some review comments
* Added some stuff to conv
* Removed debugging stuff
* Removed clamp from, and added squeeze to, the vmap failures list
* Added more conv batching rules
2022-07-21 13:40:52 -07:00
354e79afc5
[functorch] log softmax backward batch rule
2022-07-21 13:40:52 -07:00
8e0e341076
[functorch] Fix some of the more important lint errors
2022-07-21 13:40:52 -07:00
69e7dccc25
[functorch] switch from TORCH_INTERNAL_ASSERT to runtime_error
2022-07-21 13:40:52 -07:00
f7662a2101
[functorch] test/common.py -> test/common_utils.py
"from common import ..." can fail pretty easily because "common" is a
common name.
2022-07-21 13:40:52 -07:00
31a3d45c88
[functorch] Added aten::diag and aten::rsub.Scalar batching rules
2022-07-21 13:40:52 -07:00
2ddf7b0f70
[functorch] Switch PythonTensorImpl to using custom contiguity policy
For build reasons.
2022-07-21 13:40:52 -07:00
45a4dda3ad
[functorch] Switch to overriding is_contiguous_custom
There's something weird with building this internally otherwise...
2022-07-21 13:40:52 -07:00
38eb311988
[functorch] Fix clamp_min / clamp_max
2022-07-21 13:40:52 -07:00
bb701ef563
[functorch] Roll back clamp_min / clamp_max change
It doesn't compile under clang. Going to investigate this later.
2022-07-21 13:40:52 -07:00
256099c1a6
[functorch] Switched over to pytorch core's tree_map
2022-07-21 13:40:52 -07:00
66081519da
[functorch] Removed WrapModule
2022-07-21 13:40:52 -07:00
f0f15bc84a
[functorch] Batching rules for: threshold_backward, clamp_min, clamp_max
2022-07-21 13:40:52 -07:00
d9d5a52a14
[functorch] Update README's torch nightly version
2022-07-21 13:40:52 -07:00
5e90f1e61b
[functorch] Parameterized testing
2022-07-21 13:40:52 -07:00
647ce0ab8d
[functorch] removed unnecessary leftover code
2022-07-21 13:40:52 -07:00
69ae2b6dd0
[functorch] fix oversight with reusing the same storage
2022-07-21 13:40:52 -07:00
26d2ab7e35
[functorch] updated README
2022-07-21 13:40:52 -07:00
ac017463e0
[functorch] Added a simple function example
2022-07-21 13:40:52 -07:00
25453d611c
[functorch] Added nnc compilation stuff to functorch
2022-07-21 13:40:52 -07:00
4b0d62a2af
[functorch] Migrate mm and mv batch rules from old style to new style
2022-07-21 13:40:52 -07:00
9bdd9cee5d
[functorch] reshape_dim_into and reshape_dim_outof helpers
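Python equivalents of the two helpers' (assumed) semantics, as used throughout the batch rules:

```python
import torch

def reshape_dim_into(src, dst, x):
    # Merge dim `src` into dim `dst` (dst indexed w.r.t. the output,
    # which has one fewer dim): (B, N, C) with src=0, dst=0 -> (B*N, C).
    src = src % x.dim()
    dst = dst % (x.dim() - 1)
    shape = list(x.shape)
    size = shape.pop(src)
    shape[dst] *= size
    return x.movedim(src, dst).reshape(shape)

def reshape_dim_outof(src, size1, x):
    # Split dim `src` into (size1, shape[src] // size1):
    # (B*N, C) with src=0, size1=B -> (B, N, C).
    src = src % x.dim()
    shape = list(x.shape)
    assert shape[src] % size1 == 0
    shape[src:src + 1] = [size1, shape[src] // size1]
    return x.reshape(shape)
```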
2022-07-21 13:40:52 -07:00
0c5755fcaa
[functorch] Delete some redundant code
2022-07-21 13:40:52 -07:00