1694 Commits

Author SHA1 Message Date
2c4b5968cf [functorch] Add xlogy batching rule + UNARY_SCALAR macro 2022-07-21 13:41:00 -07:00
1b5cd8e227 [functorch] Fixed some more vjp failures 2022-07-21 13:41:00 -07:00
5de1f0c0a1 [functorch] Mention that users should install ninja to speed up the build 2022-07-21 13:41:00 -07:00
227d0c9c2b [functorch] Fix bug in softmax/log_softmax related to scalars + refactored code to be nicer 2022-07-21 13:41:00 -07:00
e1a342f0ba [functorch] Fix op db + nnc_compile failing tests 2022-07-21 13:41:00 -07:00
a3eb2ef978 [functorch] Added nn.functional.mse_loss batching rule 2022-07-21 13:41:00 -07:00
5466ac5da3 [functorch] Update eager_fusion.py 2022-07-21 13:41:00 -07:00
02713fd157 [functorch] Update README.md 2022-07-21 13:41:00 -07:00
93fb561a04 [functorch] _to_copy batch rule 2022-07-21 13:41:00 -07:00
d4ada7c7c3 [functorch] vmapvjpvjp testing 2022-07-21 13:40:59 -07:00
e8a5b3725b [functorch] fix example to work well :) 2022-07-21 13:40:59 -07:00
1a2e538580 [functorch] Moved to Keops example 2022-07-21 13:40:59 -07:00
9ca5ae86d0 [functorch] Added eager-mode fusion prototype 2022-07-21 13:40:59 -07:00
e632a46ff0 [functorch] Added initial NNC example of eager-mode fusion 2022-07-21 13:40:59 -07:00
7c405f44b1 [functorch] Fix "vmap: inplace arithmetic(self, *extra_args) is not possible" for linear 2022-07-21 13:40:59 -07:00
a9327ea80e [functorch] Revert TensorWrapper changes to make things more sane 2022-07-21 13:40:59 -07:00
fd7e524b4e [functorch] update binary_pointwise_batch_rule 2022-07-21 13:40:59 -07:00
4ce294d25c [functorch] Quick grab bag of batching rules 2022-07-21 13:40:59 -07:00
0a7954b9d4 [functorch] Update EXISTING_BDIM_BATCH_RULE 2022-07-21 13:40:59 -07:00
3a1a59a74b [functorch] Updated variadic_bdims_batch_rule 2022-07-21 13:40:59 -07:00
745c633687 [functorch] replace basic_unary_batch_rule with BASIC_UNARY_BATCH_RULE
For the new macro there is no need to pass in the additional non-tensor
arguments. This is dome by some mix of tensor metaprogramming.
2022-07-21 13:40:59 -07:00
7e1d730a4f [functorch] Updated squeeze and squeeze.dim batching rules to new style and added scalar handling (pytorch/functorch#81) 2022-07-21 13:40:59 -07:00
53155d3ba0 [functorch] remove old python key implementation :) 2022-07-21 13:40:59 -07:00
707c26d3db [functorch] Added vjpfull 2022-07-21 13:40:59 -07:00
a92ca843d2 [functorch] updated nnc_compile to work with new python key 2022-07-21 13:40:59 -07:00
a11bdcc411 [functorch] pythonkey refactor 2022-07-21 13:40:59 -07:00
8107b13b1e [functorch] Tensor printing 2022-07-21 13:40:59 -07:00
b3e39c1968 [functorch] Fix reshape failure by adding _reshape_alias batch rule
Not sure how to deal with the nnc lowerings though.
2022-07-21 13:40:59 -07:00
f6bb9acdbc [functorch] Add interpolate/upsample batching rules 2022-07-21 13:40:59 -07:00
308f477598 [functorch] Finished most of the rest of the pad batching rules (also added an existing_batch_dim_template) 2022-07-21 13:40:59 -07:00
a5a7245e46 [functorch] Added batching rules for constant_pad_nd 2022-07-21 13:40:59 -07:00
266d2ced94 [functorch] updated opinfo db and fixed failing tests 2022-07-21 13:40:58 -07:00
ed65f0a83a [functorch] Revert "We actually fixed a lot of the vjp problems"
This reverts commit 3d8d4504569bc335e15f67d926bfa73b13a4618b.
2022-07-21 13:40:58 -07:00
22ad9d473a [functorch] We actually fixed a lot of the vjp problems 2022-07-21 13:40:58 -07:00
2ceba07ea5 [functorch] Resolved some vjp failures 2022-07-21 13:40:58 -07:00
5cfa5728eb [functorch] Normalize, cross_entropy OpInfos 2022-07-21 13:40:58 -07:00
f8caad7fb1 [functorch] OpInfo for pad 2022-07-21 13:40:58 -07:00
4028de1d45 [functorch] additional OpInfo for Conv2d 2022-07-21 13:40:58 -07:00
f28e199609 [functorch] More OpInfos 2022-07-21 13:40:58 -07:00
6b59f1ad78 [functorch] Add functorch_additional_op_db
PyTorch OpInfo coverage doesn't include nn.functional.* ops. We can use
functorch_additional_op_db as a staging ground where we implement
nn.functional.* OpInfos and then later upstream them.

Why not just upstream them immediately? I can write OpInfos a few times
faster if I don't have to worry about figuring out what the correct
flags are to pass PyTorch tests...
2022-07-21 13:40:58 -07:00
d506951937 [functorch] Some more batch rules for pointwise ops 2022-07-21 13:40:58 -07:00
983a43cfc9 [functorch] batch rules for torch.special unary ops 2022-07-21 13:40:58 -07:00
1b78cae7b6 [functorch] Quick grab bag of batch rules 2022-07-21 13:40:58 -07:00
6ca3e96eef [functorch] Fix CI; for real this time 2022-07-21 13:40:58 -07:00
2e7ddb7a86 [functorch] Fix ci 2022-07-21 13:40:58 -07:00
6d39fa335b [functorch] Added some make_fx+vjp/jac/vmap tests 2022-07-21 13:40:58 -07:00
8e62e271be [functorch] Add make_fx(grad(..)) test 2022-07-21 13:40:58 -07:00
236d2f20b6 [functorch] Update README.md 2022-07-21 13:40:58 -07:00
046453c66b [functorch] Added a quick benchmark 2022-07-21 13:40:58 -07:00
8d127816d3 [functorch] Added citation section 2022-07-21 13:40:58 -07:00