Files
pytorch/torch/legacy/nn/Index.py
Tongzhou Wang 1c01eabd3c Codemod to update our codebase to 0.4 standard (#6641)
* Codemod to update our codebase to 0.4 standard

* Update some of the test scripts

* remove Variable in test_clip_grad_value

* fix _symbolic_override_wrapper_maker
2018-04-17 22:06:54 -04:00

26 lines
699 B
Python

import torch
from .Module import Module
class Index(Module):
    """Select slices of a tensor along a fixed dimension with an index tensor.

    The module's input is a two-element sequence ``[t, index]``. The output is
    ``torch.index_select(t, dimension, index)``. In the backward pass only
    ``t`` receives a gradient; no gradient is produced for ``index``.
    """

    def __init__(self, dimension):
        """Create the module.

        Args:
            dimension: the dimension of ``t`` along which ``index`` selects.
        """
        super(Index, self).__init__()
        self.dimension = dimension
        # The input is a list ([t, index]), so gradInput is a list too. It wraps
        # the gradInput buffer created by the base class and holds only the
        # gradient for ``t``; the index tensor gets no gradient entry.
        self.gradInput = [self.gradInput]

    def updateOutput(self, input):
        """Compute ``index_select(t, dimension, index)`` into ``self.output``.

        Args:
            input: a sequence whose element 0 is the source tensor ``t`` and
                whose element 1 is the index tensor. Presumably a 1-D
                LongTensor, as ``torch.index_select`` requires.

        Returns:
            ``self.output``, resized and filled in place by ``index_select``.
        """
        t = input[0]
        index = input[1]
        torch.index_select(t, self.dimension, index, out=self.output)
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Route ``gradOutput`` back to the selected positions of ``t``.

        Args:
            input: the same ``[t, index]`` sequence given to ``updateOutput``.
            gradOutput: gradient w.r.t. the module's output.

        Returns:
            ``self.gradInput``, a list whose single element is the gradient
            w.r.t. ``t``: zeros shaped like ``t``, with ``gradOutput`` slices
            added at the positions named by ``index``.
        """
        t = input[0]
        index = input[1]
        if not self.gradInput:
            # Defensive: if the buffer list has been emptied (presumably by a
            # state-clearing helper elsewhere; TODO confirm), recreate the
            # buffer with t's tensor type instead of raising IndexError below.
            self.gradInput.append(t.new())
        gradInput = self.gradInput[0]  # no gradient for the index tensor
        gradInput.resize_as_(t).zero_()
        # index_add_ accumulates, so a slice of t selected several times
        # receives the sum of all the corresponding gradOutput slices.
        gradInput.index_add_(self.dimension, index, gradOutput)
        return self.gradInput