Add sparse gradient option to gather operation (#17182)

Summary:
This PR allows `gather` to optionally return sparse gradients, as requested in #16329. It also allows the autograd engine to accumulate sparse gradients in place when it is safe to do so.
I've commented out the size.size() check in `SparseTensor.cpp` that also caused #17152; that check does not seem to me to serve a useful purpose, but please correct me if I'm wrong and a better fix is required.
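A minimal sketch of what the new option does (tensor shapes and values are illustrative assumptions, not from the PR): with `sparse_grad=True`, the gradient accumulated into the gathered input is a sparse tensor with nonzeros only at the gathered positions.
```
import torch

x = torch.randn(4, 6, requires_grad=True)
idx = torch.tensor([[0], [2], [5], [1]])
# Gather one element per row; request a sparse gradient w.r.t. x.
out = torch.gather(x, dim=1, index=idx, sparse_grad=True)
out.sum().backward()
print(x.grad.is_sparse)   # True
print(x.grad.to_dense())  # ones at (0,0), (1,2), (2,5), (3,1); zeros elsewhere
```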
Motivating example:
For this commonly used label smoothing loss function
```
import torch

def label_smoothing_opt(x, target):
    padding_idx = 0
    smoothing = 0.1
    # log-probabilities in float32 for numerical stability
    logprobs = torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float32)
    pad_mask = (target == padding_idx)
    # gather the log-probability of the target class; sparse_grad=True makes
    # the backward gradient w.r.t. logprobs a sparse tensor
    ll_loss = logprobs.gather(dim=-1, index=target.unsqueeze(1), sparse_grad=True).squeeze(1)
    smooth_loss = logprobs.mean(dim=-1)
    loss = (smoothing - 1.0) * ll_loss - smoothing * smooth_loss
    loss.masked_fill_(pad_mask, 0)
    return loss.sum()
```
the backward pass goes from 12.6 ms with dense gather gradients to 7.3 ms with sparse gradients for 9K tokens x 30K vocab, which amounts to a low single-digit percent end-to-end improvement, along with a reduction in peak memory usage.
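For reference, a rough sketch of how such a comparison could be timed. The sizes match the 9K x 30K case above, but the author's actual benchmarking setup is not part of this PR, so treat the setup (and the need for a CUDA device) as assumptions.
```
import time
import torch

def bench(sparse_grad, n_tokens=9000, vocab=30000, iters=20):
    # Build the gather-based part of the loss above, toggling sparse_grad.
    x = torch.randn(n_tokens, vocab, device="cuda", requires_grad=True)
    target = torch.randint(1, vocab, (n_tokens,), device="cuda")
    logprobs = torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float32)
    ll_loss = logprobs.gather(dim=-1, index=target.unsqueeze(1),
                              sparse_grad=sparse_grad).squeeze(1)
    loss = ((0.1 - 1.0) * ll_loss - 0.1 * logprobs.mean(dim=-1)).sum()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        x.grad = None  # reset so dense and sparse grads are not mixed
        loss.backward(retain_graph=True)
    torch.cuda.synchronize()
    return (time.time() - start) / iters * 1e3  # ms per backward

# print(bench(sparse_grad=False), bench(sparse_grad=True))
```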
Shout-out to the core devs: adding Python-exposed functions with keyword arguments through native_functions.yaml is very easy now!

cc gchanan apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17182

Differential Revision: D14158431

Pulled By: gchanan

fbshipit-source-id: c8b654611534198025daaf7a634482b3151fbade
Author: Natalia Gimelshein
Date: 2019-02-27 11:39:37 -08:00
Committed by: Facebook Github Bot
Parent: a2b9f7f484
Commit: b4572668b4
12 changed files with 81 additions and 14 deletions


@@ -1844,7 +1844,7 @@ Example::
 add_docstr(torch.gather,
            r"""
-gather(input, dim, index, out=None) -> Tensor
+gather(input, dim, index, out=None, sparse_grad=False) -> Tensor
 
 Gathers values along an axis specified by `dim`.
@@ -1865,6 +1865,7 @@ Args:
     dim (int): the axis along which to index
     index (LongTensor): the indices of elements to gather
     out (Tensor, optional): the destination tensor
+    sparse_grad(bool,optional): If ``True``, gradient w.r.t. :attr:`input` will be a sparse tensor.
 
 Example::
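For context (not part of the diff above), a small example of the documented call; the values are illustrative:
```
import torch

t = torch.tensor([[1, 2], [3, 4]])
idx = torch.tensor([[0, 0], [1, 0]])
# For a 2-D input with dim=1: out[i][j] = t[i][idx[i][j]]
print(torch.gather(t, 1, idx))  # tensor([[1, 1], [4, 3]])
# sparse_grad only changes how the backward gradient is represented;
# the forward result is identical.
```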