57f48cc691
[functorch] Fix parametrized testing
...
Using a quick hack. PyTorch core is getting parameterized testing (soon),
so we'll replace our hacky one with that once it lands.
2022-07-21 13:40:58 -07:00
fea6254d71
[functorch] Fix grad wrapper behavior when wrapping Python tensor
...
Fixes pytorch/functorch#68
2022-07-21 13:40:58 -07:00
e836870e07
[functorch] removed a comment that the previous commit fixed :)
2022-07-21 13:40:58 -07:00
d71bb37414
[functorch] ported random registrations to boxedfallback
2022-07-21 13:40:58 -07:00
b28268c8f4
[functorch] Added errors for dynamic ops
2022-07-21 13:40:58 -07:00
899fea5eab
[functorch] Batch rule for resize_ when batch dim is at front
...
I need this to test some things regarding resize_ behavior
2022-07-21 13:40:58 -07:00
0bba49c8a7
[functorch] Added dropout (tested by OpInfo, to be upstreamed)
2022-07-21 13:40:57 -07:00
a598ed6be6
[functorch] Added softmax batching rule (tested through local OpInfo, to be upstreamed)
2022-07-21 13:40:57 -07:00
abd8c710b2
[functorch] Added dist decomposition
2022-07-21 13:40:57 -07:00
23fc2e0f6e
[functorch] moved some decompositions out of batchingregistrations.cpp
2022-07-21 13:40:57 -07:00
71042b1b16
[functorch] added remove-inplace pass to nnc_jit (kinda works, but won't work with aliasing)
2022-07-21 13:40:57 -07:00
abc520f804
[functorch] Update dependency (pytorch/functorch#58)
2022-07-21 13:40:57 -07:00
a9ac8e814c
[functorch] (Grad)TensorWrapper sometimes has storage (and data_ptr) (pytorch/functorch#65)
...
TensorWrapper can have storage now:
- Case 1: If it wraps a tensor with storage, it will have storage
- Case 2: If it wraps a tensor without storage, it will not have storage
The rationale for this is to fix pytorch/functorch#7. When torch.tensor gets called, the
following happens:
- at::empty gets called
- some data from a PyObject* gets written directly into the new empty
tensor
The previous problem was that `at::empty` would return a TensorWrapper
wrapping a regular Tensor and that TensorWrapper did not have
storage/data_ptr.
It should be fine that TensorWrapper sometimes has storage. Users should
not write directly to the .data_ptr, because that would cause gradients
to be incorrect (but the same is true in regular PyTorch).
Test Plan:
- wait for tests
2022-07-21 13:40:57 -07:00
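A minimal Python sketch (not part of the commit; it assumes functorch's public grad API) of the behavior the TensorWrapper change above enables:

    import torch
    from functorch import grad

    def f(x):
        # torch.tensor calls at::empty and then writes PyObject* data
        # directly into the result; under grad, at::empty returns a
        # TensorWrapper, which now exposes storage/data_ptr whenever the
        # tensor it wraps has storage (Case 1 above).
        bias = torch.tensor(1.0)
        return (x * x + bias).sum()

    x = torch.randn(3)
    print(grad(f)(x))  # expected: 2 * x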
26503f21e1
[functorch] Install expecttest as CI step (pytorch/functorch#66)
2022-07-21 13:40:57 -07:00
74f773192f
[functorch] added ability for module attrs to be directly referenced in output
2022-07-21 13:40:57 -07:00
f10845eeaa
[functorch] Added binary_cross_entropy lowerings
2022-07-21 13:40:57 -07:00
379ae35ef2
[functorch] fixed device issues in tests
2022-07-21 13:40:57 -07:00
f561f0f665
[functorch] add no-dim cases for batchrulesreduce and move trace from the old batching rule to a new decomposition
2022-07-21 13:40:57 -07:00
0a8ef33ca8
[functorch] Removed no-dim rules for batchrulespooling and batchrulesviews
2022-07-21 13:40:57 -07:00
8f06e91546
[functorch] removed no bdim batching rules from batchrulesbinaryops
2022-07-21 13:40:57 -07:00
cb286b9b49
[functorch] removed no bdim cases from batchruleslinearalgebra
2022-07-21 13:40:57 -07:00
8cd80e0b16
[functorch] Added a test for a case where we would previously have dispatched to a batching rule despite having no bdims
2022-07-21 13:40:57 -07:00
446d0b4e4d
[functorch] Selectively enable dispatch on kBatchedKey (pytorch/functorch#63)
...
This PR makes it so that dispatch on kBatchedKey can only happen if there are
tensors batched at the current level. Otherwise, kBatchedKey is excluded
(even if there are BatchedTensors!).
To find tensors batched at the current level, we check:
- all Tensor arguments
- all Tensors inside TensorList arguments
- all Tensors inside Tensor?[] arguments
The above should be sufficient.
Dispatch for kVmapModeKey is not affected.
Test Plan:
- run all tests
- removed the special case in dot_batch_rule and added a test
2022-07-21 13:40:57 -07:00
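A minimal Python sketch (not part of the commit; it assumes functorch's public vmap API) of the dispatch behavior described above:

    import torch
    from functorch import vmap

    x = torch.randn(5, 3)  # batched along dim 0 at this vmap level
    y = torch.randn(3)     # plain tensor, never batched

    # Only x is batched here; after this change, ops that see no tensors
    # batched at the current level skip the kBatchedKey kernel entirely,
    # instead of relying on per-rule special cases (e.g. the one removed
    # from dot_batch_rule).
    out = vmap(torch.dot, in_dims=(0, None))(x, y)
    print(out.shape)  # torch.Size([5])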
bbfdcdbd79
[functorch] Added einsum batching rule
2022-07-21 13:40:57 -07:00
495112f550
[functorch] updated opinfo DB and added meshgrid batching rule
2022-07-21 13:40:57 -07:00
f6667347a2
[functorch] Exclude failing tests, add xfail test for torch.tensor
2022-07-21 13:40:57 -07:00
c02cc07c96
[functorch] Add flip batching rule
2022-07-21 13:40:57 -07:00
b20e4decc4
[functorch] Added batching rule for logsumexp
2022-07-21 13:40:57 -07:00
9c138786b7
[functorch] Added norm.Scalar and norm.ScalarOpt_dim
2022-07-21 13:40:57 -07:00
65c54e282c
[functorch] Revert "Exclude List[Optional[Tensor]] from the batched fallback"
...
This reverts commit 3f5f75208aa547a2f4ed23fd281c917d327b9819.
2022-07-21 13:40:57 -07:00
6916cc5d5b
[functorch] Added full_like and refactored factory macros a bit
2022-07-21 13:40:57 -07:00
b815b5bc6b
[functorch] Added triu/tril
2022-07-21 13:40:56 -07:00
0aedd9e8c1
[functorch] finished batching rules for div
2022-07-21 13:40:56 -07:00
345cf3ebf2
[functorch] Added sort batching rules
2022-07-21 13:40:56 -07:00
94a55b7ead
[functorch] finished off where batching rules + added OP_DECOMPOSE macro
2022-07-21 13:40:56 -07:00
7b8478332b
[functorch] Added where/_s_where, changed flatten (a composite op) to decomposition, and added (var/std).correction
2022-07-21 13:40:56 -07:00
60fbba9d0e
[functorch] Added erf/inverse/isinf/isnan batching rules
2022-07-21 13:40:56 -07:00
58319030bd
[functorch] Added isnan + topk batching rules
2022-07-21 13:40:56 -07:00
7bb63ff2ef
[functorch] Added clamp.tensor batching rule
2022-07-21 13:40:56 -07:00
f787f0be9d
[functorch] Added mode batching rule
2022-07-21 13:40:56 -07:00
4f87a0694e
[functorch] Added max/min/prod (no dim)
2022-07-21 13:40:56 -07:00
85587011a8
[functorch] Exclude List[Optional[Tensor]] from the batched fallback
2022-07-21 13:40:56 -07:00
9094e757a0
[functorch] Added cumprod/cumsum and ported over log_softmax
2022-07-21 13:40:56 -07:00
ec25616e0c
[functorch] Added maximum/minimum/clamp batching rules
2022-07-21 13:40:56 -07:00
9b00f55a46
[functorch] Added bmm batch rule
2022-07-21 13:40:56 -07:00
e9c196e4f7
[functorch] Add deg2rad/rad2deg/radian batching rules
2022-07-21 13:40:56 -07:00
ede0952ee4
[functorch] add acosh/atanh/asinh batching rules
2022-07-21 13:40:56 -07:00
ed2f58242f
[functorch] fix some bugs with reduction batching rule + add amax/amin
2022-07-21 13:40:56 -07:00
43de84e088
[functorch] Added sigmoid_backward and replaced a bunch of typedefs with decltypes
2022-07-21 13:40:56 -07:00
5af36052cf
[functorch] Added atan2 and zeros_ batching rules
2022-07-21 13:40:56 -07:00