Fix the example in the extending.func.rst (#109279)

As the title shown ,the `backward` function is missing the definition of `ind` and `ind_inv`, which will lead to error when calling backward
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109279
Approved by: https://github.com/zou3519
This commit is contained in:
FFFrog
2023-09-14 17:29:35 +00:00
committed by PyTorch MergeBot
parent 9021fb8dac
commit d4990ad5a1

View File

@ -386,6 +386,7 @@ Example::
@staticmethod
def backward(ctx, grad_output, _0, _1):
ind, ind_inv = ctx.saved_tensors
return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
# The signature of the vmap staticmethod is: