mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Enable codegen of per-dispatch key derivative formulas in derivatives.yaml (#82801)
`derivatives.yaml` can now take a `dispatch` entry which registers per-autograd dispatch key derivatives such as
```
name: foo(Tensor self, Tensor y) -> Tensor
dispatch:
Default:
x: grad
y: grad.expand(y.sizes())
AutogradNestedTensor:
x: grad
y: NestedTensor_foo_backward(grad, y)
output_differentiabilty: [True]
```
However the old schema where there is no `dispatch` entry is still supported.
Would greatly appreciate feedback on *how to improve the testing strategy* of this PR, currently have registered an aten test op in TestOps.cpp with dummy gradients in derivatives.yaml and have some tests in test_autograd.py:TestAutogradMultipleDispatch but I am not sure whether these are sufficiently rigorous.
Additionally, this PR also makes the assumption that sets like [VIEW_FUNCTIONS](ff5399e528/tools/autograd/gen_inplace_or_view_type.py (L60)
) are per-native-function and not per-native-function-and-dispatch-key. I'm not sure whether this is necessarily the case, *would there ever be a situation where (e.g. a nested_tensor op is a view op but the aten function is not or vice versa?)*
* __->__ #82801
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82801
Approved by: https://github.com/bhosmer, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
e4ea751810
commit
e3e33cfae0
@ -16,3 +16,15 @@ def with_native_function_with_differentiability_info(
|
||||
return func(f)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
# Like the above but with an additional dispatch key string argument
|
||||
def with_native_function_with_differentiability_info_and_key(
|
||||
func: Callable[[NFWDI, str], T]
|
||||
) -> Callable[[NFWDI, str], T]:
|
||||
@functools.wraps(func)
|
||||
def wrapper(f: NFWDI, key: str) -> T:
|
||||
with native_function_manager(f.func):
|
||||
return func(f, key)
|
||||
|
||||
return wrapper
|
||||
|
Reference in New Issue
Block a user