168 Commits

Author SHA1 Message Date
11fe277b62 [PrimTorch] Add reference for torch.norm (#81765)
This ref does more things than `torch.norm`, and it fixes a few bugs
that `torch.norm` has. This implementation and the `torch.norm`
implementation come to terms in the next PR of this stack

We put this PR before, as otherwise `test_decomp.py` was failing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81765
Approved by: https://github.com/ngimel
2022-07-25 19:57:21 +00:00
f595467e5c Reenable slow gradcheck and make it pass (#80514)
Context: For a while slow gradcheck CI was skipping nearly all tests and this hid the fact that it should've been failing and timing out (10+h runtime for TestGradients). The CI configuration has since been fixed to correct this, revealing the test failures. This PR reenables slow gradcheck CI and makes it pass again.

This PR:
- makes slow and failing tests run in fast gradcheck mode only
- reduce the input size for slow gradcheck only for unary/binary ufuncs (alternatively, skip the test entirely)
- skip entire test files on slow gradcheck runner if they don't use gradcheck (test_ops, test_meta, test_decomp, test_ops_jit)
- reduces the input size for some ops

Follow ups:
1. Investigate slow mode failures https://github.com/pytorch/pytorch/issues/80411
2. See if we can re-enable slow gradcheck tests for some of the slow tests by reducing the sizes of their inputs

The following are failing in slow mode, they are now running in fast mode only.
```
test_fn_fwgrad_bwgrad___rmod___cuda_float64
test_fn_fwgrad_bwgrad_linalg_householder_product_cuda_complex128
test_fn_fwgrad_bwgrad__masked_prod_cuda_complex128
test_fn_fwgrad_bwgrad__masked_prod_cuda_float64
test_fn_fwgrad_bwgrad_linalg_matrix_power_cuda_complex128
test_fn_fwgrad_bwgrad_cat_cuda_complex128
test_fn_fwgrad_bwgrad_linalg_lu_factor_ex_cuda_float64
test_fn_fwgrad_bwgrad_copysign_cuda_float64
test_fn_fwgrad_bwgrad_cholesky_inverse_cuda_complex128
test_fn_fwgrad_bwgrad_float_power_cuda_complex128
test_fn_fwgrad_bwgrad_fmod_cuda_float64
test_fn_fwgrad_bwgrad_float_power_cuda_float64
test_fn_fwgrad_bwgrad_linalg_lu_cuda_float64
test_fn_fwgrad_bwgrad_remainder_cuda_float64
test_fn_fwgrad_bwgrad_repeat_cuda_complex128
test_fn_fwgrad_bwgrad_prod_cuda_complex128
test_fn_fwgrad_bwgrad_slice_scatter_cuda_float64
test_fn_fwgrad_bwgrad_tile_cuda_complex128
test_fn_fwgrad_bwgrad_pow_cuda_float64
test_fn_fwgrad_bwgrad_pow_cuda_complex128
test_fn_fwgrad_bwgrad_fft_*
test_fn_fwgrad_bwgrad_zero__cuda_complex128
test_fn_gradgrad_linalg_lu_factor_cuda_float64
test_fn_grad_div_trunc_rounding_cuda_float64
test_fn_grad_div_floor_rounding_cuda_float64
```

Marks the OpInfos for the following ops that run slowly in slow gradcheck as `fast_gradcheck` only (the left column represents runtime in seconds):
```
0  918.722  test_fn_fwgrad_bwgrad_nn_functional_conv_transpose3d_cuda_float64
1  795.042  test_fn_fwgrad_bwgrad_nn_functional_unfold_cuda_complex128
2  583.63  test_fn_fwgrad_bwgrad_nn_functional_max_pool3d_cuda_float64
3  516.946  test_fn_fwgrad_bwgrad_svd_cuda_complex128
4  503.179  test_fn_fwgrad_bwgrad_linalg_svd_cuda_complex128
5  460.985  test_fn_fwgrad_bwgrad_linalg_lu_cuda_complex128
6  401.04  test_fn_fwgrad_bwgrad_linalg_lstsq_grad_oriented_cuda_complex128
7  353.671  test_fn_fwgrad_bwgrad_nn_functional_max_pool2d_cuda_float64
8  321.903  test_fn_fwgrad_bwgrad_nn_functional_gaussian_nll_loss_cuda_float64
9  307.951  test_fn_fwgrad_bwgrad_stft_cuda_complex128
10  266.104  test_fn_fwgrad_bwgrad_svd_lowrank_cuda_float64
11  221.032  test_fn_fwgrad_bwgrad_istft_cuda_complex128
12  183.741  test_fn_fwgrad_bwgrad_lu_unpack_cuda_complex128
13  132.019  test_fn_fwgrad_bwgrad_nn_functional_unfold_cuda_float64
14  125.343  test_fn_fwgrad_bwgrad_nn_functional_pad_constant_cuda_complex128
15  124.2  test_fn_fwgrad_bwgrad_kron_cuda_complex128
16  123.721  test_fn_fwgrad_bwgrad_pca_lowrank_cuda_float64
17  121.074  test_fn_fwgrad_bwgrad_nn_functional_max_unpool3d_cuda_float64
18  119.387  test_fn_fwgrad_bwgrad_rot90_cuda_complex128
19  112.889  test_fn_fwgrad_bwgrad__masked_normalize_cuda_complex128
20  107.541  test_fn_fwgrad_bwgrad_dist_cuda_complex128
21  106.727  test_fn_fwgrad_bwgrad_diff_cuda_complex128
22  104.588  test_fn_fwgrad_bwgrad__masked_cumprod_cuda_complex128
23  100.135  test_fn_fwgrad_bwgrad_nn_functional_feature_alpha_dropout_with_train_cuda_float64
24  88.359  test_fn_fwgrad_bwgrad_mH_cuda_complex128
25  86.214  test_fn_fwgrad_bwgrad_nn_functional_max_unpool2d_cuda_float64
26  83.037  test_fn_fwgrad_bwgrad_nn_functional_bilinear_cuda_float64
27  79.987  test_fn_fwgrad_bwgrad__masked_cumsum_cuda_complex128
28  77.822  test_fn_fwgrad_bwgrad_diag_embed_cuda_complex128
29  76.256  test_fn_fwgrad_bwgrad_mT_cuda_complex128
30  74.039  test_fn_fwgrad_bwgrad_linalg_lu_solve_cuda_complex128
```
```
0  334.142  test_fn_fwgrad_bwgrad_unfold_cuda_complex128
1  312.791  test_fn_fwgrad_bwgrad_linalg_lu_factor_cuda_complex128
2  121.963  test_fn_fwgrad_bwgrad_nn_functional_max_unpool3d_cuda_float64
3  108.085  test_fn_fwgrad_bwgrad_diff_cuda_complex128
4  89.418  test_fn_fwgrad_bwgrad_nn_functional_max_unpool2d_cuda_float64
5  72.231  test_fn_fwgrad_bwgrad___rdiv___cuda_complex128
6  69.433  test_fn_fwgrad_bwgrad___getitem___cuda_complex128
7  68.582  test_fn_fwgrad_bwgrad_ldexp_cuda_complex128
8  68.572  test_fn_fwgrad_bwgrad_linalg_pinv_cuda_complex128
9  67.585  test_fn_fwgrad_bwgrad_nn_functional_glu_cuda_float64
10  66.567  test_fn_fwgrad_bwgrad_lu_cuda_float64
```
```
0  630.13  test_fn_gradgrad_nn_functional_conv2d_cuda_complex128
1  81.086  test_fn_gradgrad_linalg_solve_triangular_cuda_complex128
2  71.332  test_fn_gradgrad_norm_cuda_complex128
3  64.308  test_fn_gradgrad__masked_std_cuda_complex128
4  59.519  test_fn_gradgrad_div_no_rounding_mode_cuda_complex128
5  58.836  test_fn_gradgrad_nn_functional_adaptive_avg_pool3
```

Reduces the sizes of the inputs for:
- diff
- diag_embed

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80514
Approved by: https://github.com/albanD
2022-07-22 02:05:37 +00:00
6cf0d9249f Add prelu and relu6 refs missing from __all__ and decomp db (#81420)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81420
Approved by: https://github.com/ngimel
2022-07-19 20:50:04 +00:00
bc1fef96af Reference implementations for rsqrt and native_layer_norm (#79413)
This PR adds references for:
- `torch.rsqrt`
- `torch.native_layer_norm`
-  `torch.nn.functional.layer_norm`

`native_layer_norm` had a different number of dimensions if the input was 0-sized. I fixed that.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79413
Approved by: https://github.com/mruberry, https://github.com/Chillee
2022-06-17 07:24:02 +00:00
fd37d1d870 Minor updates for upcast_tensor
Based on comments in https://github.com/pytorch/pytorch/pull/78459

Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78525

Approved by: https://github.com/albanD
2022-06-01 00:46:52 +00:00
eee2aa14a6 Register std_mean ref as a decomposition
Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78468

Approved by: https://github.com/ngimel
2022-05-31 18:59:16 +00:00
678213ead2 Fake Tensor Part 1
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77969

Approved by: https://github.com/ezyang
2022-05-31 16:20:35 +00:00
5620ebad5f Unconditionally transform dtype arguments to double for upcast
This ensures we can also hit functions like torch.sum that take
a dtype argument, without having to manually list them all.

Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78459

Approved by: https://github.com/albanD
2022-05-31 14:18:49 +00:00
64b4bb4b01 Fix meta tests on norm (and relanding norm fixes) (#77930)
Had a land race with meta tests.

Will also be relanding https://github.com/pytorch/pytorch/pull/77407
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77930
Approved by: https://github.com/malfet, https://github.com/ezyang
2022-05-20 23:15:53 +00:00
03546e9c07 Revert "Fixed type promotion semantics for native_batch_norm and native_layer_norm (#77407)"
This reverts commit 70d80fb42480f4df7bd369f8f9f1500b58c5c603.

Reverted https://github.com/pytorch/pytorch/pull/77407 on behalf of https://github.com/malfet due to as it broke meta tests ( I guess due to landrace), see 70d80fb424
2022-05-20 02:31:57 +00:00
70d80fb424 Fixed type promotion semantics for native_batch_norm and native_layer_norm (#77407)
Originally, when these were written, they simply used the naive strategy of "upcast all inputs to floats, and downcast all inputs back". In addition to being... not quite what the kernels did, they also didn't capture some additional semantics. Namely, that the norms (except for layer norm on CPU! cc: @ngimel) return fp32 for the mean and rstd values.

Also, folks didn't like that I wrote `native_layer_norm` in terms of `native_batch_norm`. Which is fair - so I refactored the common logic into a `normalize` function.

cc: @jansel / @bertmaher , who've been looking at lowering layer norm/batch norm.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77407
Approved by: https://github.com/bertmaher
2022-05-19 17:11:47 +00:00
0a14a4c280 Register prims as operators.
This makes prims look as if they were defined in native_functions.yaml
but they're still all written in Python.  You now need to give a full
schema string for your prims.  The returned prim object is now
torch.ops.prim overload (prims are not allowed to be overloaded,
so we return the overload, not the overload packet, for speed.)

Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77117

Approved by: https://github.com/mruberry, https://github.com/albanD
2022-05-11 16:38:14 +00:00
f2eed9400d Register PrimTorch refs as decompositions.
For the most part, PrimTorch refs have the same signature as their
ATen equivalents.  I modify most PrimTorch refs to register themselves
as decompositions, using the prim name they wrap to find the aten name
(except for a few cases where the prim/aten names mismatch).  There are
some exclusions, falling into one of two categories:

- The torch equivalent was already implemented as a CompositeImplicitAutograd
  decomposition in C++

- The ref doesn't support enough features (e.g., the real deal has more
  kwargs / overloads than are currently implemented)

PrimTorch refs are written as a single function that supports all
overloads, and this style is convenient for cases where we have a bundle
of overloads for what morally is a single overload with a Union type
on an argument (which we ought to have supported in
native_functions.yaml but blah); to support registering a single decomp
for all the overloads, we modify register_decomposition to register
to ALL overloads if you pass it an overload packet.  This is technically
BC breaking but no tests started failing because of it.

Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76835

Approved by: https://github.com/Chillee, https://github.com/mruberry
2022-05-06 20:11:45 +00:00
6917034afb Added logit/reciprocal decomps, fixed var for complex, moved type promotion logic to standardize on primtorch's
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76633
Approved by: https://github.com/ezyang
2022-05-04 21:29:52 +00:00
ed18181d83 Added gelu decomposition
^
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76763
Approved by: https://github.com/ezyang
2022-05-03 23:23:18 +00:00
598e7e5f19 [Reland] Change 'python mode' to 'torch dispatch mode'
Changes Python Mode name to Torch Dispatch Mode because there is now a Torch Function Mode, so Torch Dispatch Mode and Torch Function Mode are consistent with each other
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76562
Approved by: https://github.com/zou3519, https://github.com/albanD
2022-05-02 20:06:43 +00:00
fb24614011 Port functorch decomps over and fix some tests
Still some stuff to fix up, will finish later.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76621
Approved by: https://github.com/ezyang
2022-05-01 08:48:48 +00:00
a3f10ec281 Move functorch decompositions to PyTorch
Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76311

Approved by: https://github.com/Chillee
2022-04-30 16:47:53 +00:00