added more gotchas + a Common Test Failures portion

Jane (Yuan) Xu
2022-10-26 10:54:18 -04:00
parent 642840de90
commit 42957dc8e5

@ -100,3 +100,17 @@ Here are a few of them that are important for the lab above.
- Tensor full of zeros, `None` in python and undefined Tensors in c++ all mean the same thing for gradients. This means that your backward function need to properly handle potential None/undefined Tensors and behave as-if they were Tensors full of zeros. Similarly, your backward can return None/undefined Tensors instead of a Tensor full of zeros if needed.
- Don't forget to use `ctx.set_materialize_grads()` described in the extending doc on your custom Function to prevent zero Tensors from being materialized.
- The dtype that the backward functions support might be different than the ones that the forward supports for some already defined functions. OpInfo provides many options to specify dtypes: `dtypes`, `dtypesIfCUDA`, `backward_dtypes` and `backward_dtypesIfCUDA`.
- Remember to handle `__torch_function__` for your `attn` op and add an entry in `get_testing_overrides()` in `torch/overrides.py`. Feel free to use one of the other ops as an example.
- PyTorch uses symbolic ints now, so for any `size()` or `sizes()` call you make in C++, replace them with `sym_size()` and `sym_sizes()` to handle SymInts appropriately.
- When making attn CompositeExplicit, forward AD and forward over reverse will no longer be derived for you. OpInfo provides options to specify `supports_forward_ad` and `supports_fwgrad_bwgrad`.
## Common Test Failures
It is always better to test locally first before submitting a PR. A few relevant test files you should verify before submitting your PR are: `test_ops.py`, `test_overrides.py`, `functorch/test_aotdispatch.py`, `functorch/test_ops.py`, `functorch/test_vmap.py`, `test_proxy_tensor.py`. To run only the ones with `attn` in name, use:
```
python test/<test_file.py> -k attn
```
- You can add an xfail decorator for:
- functorch tests that fail with `hit the vmap fallback which is currently disabled`
- tests requiring forward AD to be implemented _when you turn your op from CompositeImplicit to CompositeExplicit_, since the forward AD will no longer be derived for you.
- tests requiring decompositions will fail with messages like `aten.attn.default - couldn't find symbolic meta function/decomposition` _when you turn your op from CompositeImplicit to CompositeExplicit_, since those won't be derived for you anymore either.
- You can skip `attn` in the fake_autocast_device tests for CPU _when you turn your op from CompositeImplicit to CompositeExplicit_.