1694 Commits

Author SHA1 Message Date
57f48cc691 [functorch] Fix parametrized testing
Using a quick hack. Pytorch core is getting parameterized testing (soon)
so we'll replace our hacky one with that once it is in.
2022-07-21 13:40:58 -07:00
fea6254d71 [functorch] Fix grad wrapper behavior when wrapping Python tensor
Fixes pytorch/functorch#68
2022-07-21 13:40:58 -07:00
e836870e07 [functorch] removed comment that previous commit fixed :) 2022-07-21 13:40:58 -07:00
d71bb37414 [functorch] ported random registrations to boxedfallback 2022-07-21 13:40:58 -07:00
b28268c8f4 [functorch] Added errors for dynamic ops 2022-07-21 13:40:58 -07:00
899fea5eab [functorch] Batch rule for resize_ when batch dim is at front
I need this to test some things regarding resize_ behavior
2022-07-21 13:40:58 -07:00
0bba49c8a7 [functorch] Added dropout (tested by opinfo to be upstreamed) 2022-07-21 13:40:57 -07:00
a598ed6be6 [functorch] Added softmax batching rule (tested through local OpInfo, to be upstreamed) 2022-07-21 13:40:57 -07:00
abd8c710b2 [functorch] Added dist decomposition 2022-07-21 13:40:57 -07:00
23fc2e0f6e [functorch] moved some decompositions from batchingregistrations.cpp out 2022-07-21 13:40:57 -07:00
71042b1b16 [functorch] added remove-inplace pass to nnc_jit (kinda works, but won't work with aliasing) 2022-07-21 13:40:57 -07:00
abc520f804 [functorch] Update dependency (pytorch/functorch#58) 2022-07-21 13:40:57 -07:00
a9ac8e814c [functorch] (Grad)TensorWrapper sometimes has storage (and data_ptr) (pytorch/functorch#65)
TensorWrapper can have storage now:
- Case 1: If it wraps a tensor with storage, it will have storage
- Case 2: If it wraps a tensor without storage, it will not have storage

The rationale for this is to fix pytorch/functorch#7. When torch.tensor gets called, the
following happens:
- at::empty gets called
- some data from a PyObject* gets written directly into the new empty
tensor

The previous problem was that `at::empty` would return a TensorWrapper
wrapping a regular Tensor and that TensorWrapper did not have
storage/data_ptr.

It should be fine that TensorWrapper sometimes has storage. Users should
not write directly to the .data_ptr (because that would cause gradients
to be incorrect, but it is the same in regular PyTorch).

Test Plan:
- wait for tests
2022-07-21 13:40:57 -07:00
26503f21e1 [functorch] Install expecttest as CI step (pytorch/functorch#66) 2022-07-21 13:40:57 -07:00
74f773192f [functorch] added ability for module attrs to be directly referenced in output 2022-07-21 13:40:57 -07:00
f10845eeaa [functorch] Added binary_cross_entropy lowerings 2022-07-21 13:40:57 -07:00
379ae35ef2 [functorch] fixed device issues in tests 2022-07-21 13:40:57 -07:00
f561f0f665 [functorch] add no-dim for batchrulesreduce and move trace from old batching rule to new decomposition 2022-07-21 13:40:57 -07:00
0a8ef33ca8 [functorch] Removed no-dim rules for batchrulespooling and batchrulesviews 2022-07-21 13:40:57 -07:00
8f06e91546 [functorch] removed no bdim batching rules from batchrulesbinaryops 2022-07-21 13:40:57 -07:00
cb286b9b49 [functorch] removed no bdim cases from batchruleslinearalgebra 2022-07-21 13:40:57 -07:00
8cd80e0b16 [functorch] Added a test for case where we would previously have dispatched to a batching rule despite having no bdims 2022-07-21 13:40:57 -07:00
446d0b4e4d [functorch] Selectively enable dispatch on kBatchedKey (pytorch/functorch#63)
This PR makes it so that dispatch on kBatchedKey only can happen if there are
tensors batched at the current level. Otherwise, kBatchedKey is excluded
(even if there are BatchedTensors!).

To find tensors batched at the current level, we check:
- all tensor arguments
- we peek into all TensorLists
- we peek into all Tensor?[].
the above bullet points should be sufficient.

Dispatch for kVmapModeKey is not affected.

Test Plan:
- run all tests
- removed the special case in dot_batch_rule and added a test
2022-07-21 13:40:57 -07:00
bbfdcdbd79 [functorch] Added einsum batching rule 2022-07-21 13:40:57 -07:00
495112f550 [functorch] updated opinfo DB and added meshgrid batching rule 2022-07-21 13:40:57 -07:00
f6667347a2 [functorch] Exclude failing tests, add xfail test for torch.tensor 2022-07-21 13:40:57 -07:00
c02cc07c96 [functorch] Add flip batching rule 2022-07-21 13:40:57 -07:00
b20e4decc4 [functorch] Added batching rule for logsumexp 2022-07-21 13:40:57 -07:00
9c138786b7 [functorch] Added norm.Scalar and norm.ScalarOpt_dim 2022-07-21 13:40:57 -07:00
65c54e282c [functorch] Revert "Exclude List[Optional[Tensor]] from the batched fallback"
This reverts commit 3f5f75208aa547a2f4ed23fd281c917d327b9819.
2022-07-21 13:40:57 -07:00
6916cc5d5b [functorch] Added full_like and refactored factory macros a bit 2022-07-21 13:40:57 -07:00
b815b5bc6b [functorch] Added triu/tril 2022-07-21 13:40:56 -07:00
0aedd9e8c1 [functorch] finished batching rules for div 2022-07-21 13:40:56 -07:00
345cf3ebf2 [functorch] Added sort batching rules 2022-07-21 13:40:56 -07:00
94a55b7ead [functorch] finished off where batching rules + added OP_DECOMPOSE macro 2022-07-21 13:40:56 -07:00
7b8478332b [functorch] Added where/_s_where, changed flatten (a composite op) to decomposition, and added (var/std).correction 2022-07-21 13:40:56 -07:00
60fbba9d0e [functorch] Added erf/inverse/isinf/isnan batching rules 2022-07-21 13:40:56 -07:00
58319030bd [functorch] Added isnan + topk batching rules 2022-07-21 13:40:56 -07:00
7bb63ff2ef [functorch] Added clamp.tensor batching rule 2022-07-21 13:40:56 -07:00
f787f0be9d [functorch] Added mode batching rule 2022-07-21 13:40:56 -07:00
4f87a0694e [functorch] Added max/min/prod (no dim) 2022-07-21 13:40:56 -07:00
85587011a8 [functorch] Exclude List[Optional[Tensor]] from the batched fallback 2022-07-21 13:40:56 -07:00
9094e757a0 [functorch] Added cumprod/cumsum and ported over log_softmax 2022-07-21 13:40:56 -07:00
ec25616e0c [functorch] Added maximum/minimum/clamp batching rules 2022-07-21 13:40:56 -07:00
9b00f55a46 [functorch] Added bmm batch rule 2022-07-21 13:40:56 -07:00
e9c196e4f7 [functorch] Add deg2rad/rad2deg/radian batching rules 2022-07-21 13:40:56 -07:00
ede0952ee4 [functorch] add acosh/atanh/asinh batching rules 2022-07-21 13:40:56 -07:00
ed2f58242f [functorch] fix some bugs with reduction batching rule + add amax/amin 2022-07-21 13:40:56 -07:00
43de84e088 [functorch] Added sigmoid_backward and replaced a bunch of typedefs with decltypes 2022-07-21 13:40:56 -07:00
5af36052cf [functorch] Added atan2 and zeros_ batching rules 2022-07-21 13:40:56 -07:00