mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix the example in the extending.func.rst (#109279)
As the title shown ,the `backward` function is missing the definition of `ind` and `ind_inv`, which will lead to error when calling backward Pull Request resolved: https://github.com/pytorch/pytorch/pull/109279 Approved by: https://github.com/zou3519
This commit is contained in:
@ -386,6 +386,7 @@ Example::
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output, _0, _1):
|
||||
ind, ind_inv = ctx.saved_tensors
|
||||
return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
|
||||
|
||||
# The signature of the vmap staticmethod is:
|
||||
|
Reference in New Issue
Block a user