2c4b5968cf
[functorch] Add xlogy batching rule + UNARY_SCALAR macro
2022-07-21 13:41:00 -07:00
1b5cd8e227
[functorch] Fixed some more vjp failures
2022-07-21 13:41:00 -07:00
5de1f0c0a1
[functorch] Mention that users should install ninja to speed up the build
2022-07-21 13:41:00 -07:00
227d0c9c2b
[functorch] Fix bug in softmax/log_softmax related to scalars + refactored code to be nicer
2022-07-21 13:41:00 -07:00
e1a342f0ba
[functorch] Fix op db + nnc_compile failing tests
2022-07-21 13:41:00 -07:00
a3eb2ef978
[functorch] Added nn.functional.mse_loss batching rule
2022-07-21 13:41:00 -07:00
5466ac5da3
[functorch] Update eager_fusion.py
2022-07-21 13:41:00 -07:00
02713fd157
[functorch] Update README.md
2022-07-21 13:41:00 -07:00
93fb561a04
[functorch] _to_copy batch rule
2022-07-21 13:41:00 -07:00
d4ada7c7c3
[functorch] vmapvjpvjp testing
2022-07-21 13:40:59 -07:00
e8a5b3725b
[functorch] fix example to work well :)
2022-07-21 13:40:59 -07:00
1a2e538580
[functorch] Moved to Keops example
2022-07-21 13:40:59 -07:00
9ca5ae86d0
[functorch] Added eager-mode fusion prototype
2022-07-21 13:40:59 -07:00
e632a46ff0
[functorch] Added initial NNC example of eager-mode fusion
2022-07-21 13:40:59 -07:00
7c405f44b1
[functorch] Fix "vmap: inplace arithmetic(self, *extra_args) is not possible" for linear
2022-07-21 13:40:59 -07:00
a9327ea80e
[functorch] Revert TensorWrapper changes to make things more sane
2022-07-21 13:40:59 -07:00
fd7e524b4e
[functorch] update binary_pointwise_batch_rule
2022-07-21 13:40:59 -07:00
4ce294d25c
[functorch] Quick grab bag of batching rules
2022-07-21 13:40:59 -07:00
0a7954b9d4
[functorch] Update EXISTING_BDIM_BATCH_RULE
2022-07-21 13:40:59 -07:00
3a1a59a74b
[functorch] Updated variadic_bdims_batch_rule
2022-07-21 13:40:59 -07:00
745c633687
[functorch] replace basic_unary_batch_rule with BASIC_UNARY_BATCH_RULE
...
For the new macro there is no need to pass in the additional non-tensor
arguments. This is done by some mix of template metaprogramming.
2022-07-21 13:40:59 -07:00
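The unary batch rule described above can be sketched in plain Python (a hypothetical model: the real BASIC_UNARY_BATCH_RULE is a C++ macro, and the function name and tensor-as-list representation here are illustrative only):

```python
import math

# Hypothetical model: a vmapped tensor is a (data, bdim) pair. For a
# unary pointwise op the rule just applies op elementwise; the batch
# dim passes through unchanged, and no extra non-tensor arguments are
# needed -- which is what the macro derives automatically.
def basic_unary_batch_rule(op, data, bdim):
    return [op(x) for x in data], bdim

out, out_bdim = basic_unary_batch_rule(math.exp, [0.0, 1.0], 0)
```

Because unary pointwise ops act independently on each element, the same one-line rule covers the whole family, which is why a single macro can stamp them all out.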
7e1d730a4f
[functorch] Updated squeeze and squeeze.dim batching rules to new style and added scalar handling ( pytorch/functorch#81 )
2022-07-21 13:40:59 -07:00
53155d3ba0
[functorch] remove old python key implementation :)
2022-07-21 13:40:59 -07:00
707c26d3db
[functorch] Added vjpfull
2022-07-21 13:40:59 -07:00
a92ca843d2
[functorch] updated nnc_compile to work with new python key
2022-07-21 13:40:59 -07:00
a11bdcc411
[functorch] pythonkey refactor
2022-07-21 13:40:59 -07:00
8107b13b1e
[functorch] Tensor printing
2022-07-21 13:40:59 -07:00
b3e39c1968
[functorch] Fix reshape failure by adding _reshape_alias batch rule
...
Not sure how to deal with the nnc lowerings though.
2022-07-21 13:40:59 -07:00
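The shape logic behind a view-like batch rule such as the one for _reshape_alias can be sketched as follows (a hypothetical shape-only model; the function name is illustrative, not functorch's API):

```python
# Hypothetical sketch: with the batch dim moved to the front, a batched
# reshape just prepends the batch size to the requested shape, so each
# example is reshaped independently and the output keeps bdim 0.
def reshape_batch_rule_shape(shape_with_bdim, requested_shape):
    batch_size = shape_with_bdim[0]        # bdim assumed at the front
    return (batch_size, *requested_shape)  # per-example reshape

# a (4, 6) tensor vmapped over dim 0, reshaped to (2, 3) per example
out_shape = reshape_batch_rule_shape((4, 6), (2, 3))
```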
f6bb9acdbc
[functorch] Add interpolate/upsample batching rules
2022-07-21 13:40:59 -07:00
308f477598
[functorch] Finished most of the rest of the pad batching rules (also added an existing_batch_dim_template)
2022-07-21 13:40:59 -07:00
a5a7245e46
[functorch] Added batching rules for constant_pad_nd
2022-07-21 13:40:59 -07:00
266d2ced94
[functorch] updated opinfo db and fixed failing tests
2022-07-21 13:40:58 -07:00
ed65f0a83a
[functorch] Revert "We actually fixed a lot of the vjp problems"
...
This reverts commit 3d8d4504569bc335e15f67d926bfa73b13a4618b.
2022-07-21 13:40:58 -07:00
22ad9d473a
[functorch] We actually fixed a lot of the vjp problems
2022-07-21 13:40:58 -07:00
2ceba07ea5
[functorch] Resolved some vjp failures
2022-07-21 13:40:58 -07:00
5cfa5728eb
[functorch] Normalize, cross_entropy OpInfos
2022-07-21 13:40:58 -07:00
f8caad7fb1
[functorch] OpInfo for pad
2022-07-21 13:40:58 -07:00
4028de1d45
[functorch] additional OpInfo for Conv2d
2022-07-21 13:40:58 -07:00
f28e199609
[functorch] More OpInfos
2022-07-21 13:40:58 -07:00
6b59f1ad78
[functorch] Add functorch_additional_op_db
...
PyTorch OpInfo coverage doesn't include nn.functional.* ops. We can use
functorch_additional_op_db as a staging ground where we implement
nn.functional.* OpInfos and then later upstream them.
Why not just upstream them immediately? I can write OpInfos a few times
faster if I don't have to worry about figuring out what the correct
flags are to pass PyTorch tests...
2022-07-21 13:40:58 -07:00
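A pared-down sketch of what one of these staged nn.functional.* OpInfo entries might carry (hypothetical field set and class name; the real OpInfo class in torch.testing has many more test-control flags, which is exactly the friction the commit describes):

```python
from dataclasses import dataclass
from typing import Callable

# Hypothetical, minimal OpInfo-style record: an op under test plus a
# generator of sample inputs, with the autograd flag as one example of
# the many flags the real class requires.
@dataclass
class OpInfoSketch:
    name: str
    op: Callable
    sample_inputs_func: Callable
    supports_autograd: bool = True

def sample_inputs_mse_loss():
    # (input, target) pairs as plain lists -- hypothetical samples
    return [([1.0, 2.0], [1.0, 0.0]), ([3.0], [3.0])]

mse = OpInfoSketch(
    name="nn.functional.mse_loss",
    op=lambda i, t: sum((a - b) ** 2 for a, b in zip(i, t)) / len(i),
    sample_inputs_func=sample_inputs_mse_loss,
)
```

A test harness can then iterate over such records and exercise each op against vmap/vjp transforms on every sample input.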
d506951937
[functorch] Some more batch rules for pointwise ops
2022-07-21 13:40:58 -07:00
983a43cfc9
[functorch] batch rules for torch.special unary ops
2022-07-21 13:40:58 -07:00
1b78cae7b6
[functorch] Quick grab bag of batch rules
2022-07-21 13:40:58 -07:00
6ca3e96eef
[functorch] Fix CI; for real this time
2022-07-21 13:40:58 -07:00
2e7ddb7a86
[functorch] Fix ci
2022-07-21 13:40:58 -07:00
6d39fa335b
[functorch] Added some make_fx+vjp/jac/vmap tests
2022-07-21 13:40:58 -07:00
8e62e271be
[functorch] Add make_fx(grad(..)) test
2022-07-21 13:40:58 -07:00
236d2f20b6
[functorch] Update README.md
2022-07-21 13:40:58 -07:00
046453c66b
[functorch] Added a quick benchmark
2022-07-21 13:40:58 -07:00
8d127816d3
[functorch] Added citation section
2022-07-21 13:40:58 -07:00