Add sparse gradient option to gather operation (#17182)
Summary:
This PR allows `gather` to optionally return sparse gradients, as requested in #16329. It also allows the autograd engine to accumulate sparse gradients in place when it is safe to do so.
I've commented out the size.size() check in `SparseTensor.cpp` that also caused #17152. It does not seem to me that the check serves a useful purpose, but please correct me if I'm wrong and a better fix is required.

Motivating example:
For this commonly used label smoothing loss function
```
import torch

def label_smoothing_opt(x, target):
    padding_idx = 0
    smoothing = 0.1
    # Log-probabilities over the vocabulary, computed in float32.
    logprobs = torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float32)
    pad_mask = (target == padding_idx)
    # Pick the log-probability of each target token; request a sparse gradient.
    ll_loss = logprobs.gather(dim=-1, index=target.unsqueeze(1), sparse_grad=True).squeeze(1)
    smooth_loss = logprobs.mean(dim=-1)
    loss = (smoothing - 1.0) * ll_loss - smoothing * smooth_loss
    # Padding positions do not contribute to the loss.
    loss.masked_fill_(pad_mask, 0)
    return loss.sum()
```
backward goes from 12.6 ms with dense gather gradients to 7.3 ms with sparse gradients, for 9K tokens x 30K vocab. That is a few percent end-to-end improvement, along with a reduction in the peak memory required.
Shout-out to core devs: adding python-exposed functions with keyword arguments through native_functions.yaml is very easy now!

cc gchanan apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17182

Differential Revision: D14158431

Pulled By: gchanan

fbshipit-source-id: c8b654611534198025daaf7a634482b3151fbade
Committed by: Facebook Github Bot
Parent: a2b9f7f484
Commit: b4572668b4
@@ -1844,7 +1844,7 @@ Example::
 
 add_docstr(torch.gather,
            r"""
-gather(input, dim, index, out=None) -> Tensor
+gather(input, dim, index, out=None, sparse_grad=False) -> Tensor
 
 Gathers values along an axis specified by `dim`.
 
@@ -1865,6 +1865,7 @@ Args:
     dim (int): the axis along which to index
     index (LongTensor): the indices of elements to gather
     out (Tensor, optional): the destination tensor
+    sparse_grad(bool,optional): If ``True``, gradient w.r.t. :attr:`input` will be a sparse tensor.
 
 Example::
 
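The documented effect of `sparse_grad` can be checked with a short script. The following is a minimal sketch rather than part of the commit: it assumes a PyTorch build that includes this change, and the tensor shapes, index values, and variable names are made up for illustration.

```python
import torch

# Two independent copies of the same input, so the dense and sparse
# gradients can be compared. Shapes are arbitrary (illustrative only).
x_dense = torch.randn(4, 6, requires_grad=True)
x_sparse = x_dense.detach().clone().requires_grad_(True)

# One gathered column per row; index has the same number of dims as the input.
index = torch.tensor([[0], [2], [5], [1]])

# Default behaviour: the gradient w.r.t. the input is a dense tensor.
torch.gather(x_dense, 1, index).sum().backward()

# With sparse_grad=True the gradient w.r.t. the input is a sparse tensor.
torch.gather(x_sparse, 1, index, sparse_grad=True).sum().backward()

grad_is_sparse = x_sparse.grad.is_sparse
grads_match = torch.equal(x_sparse.grad.to_dense(), x_dense.grad)
print(grad_is_sparse, grads_match)
```

Only the layout of the gradient differs between the two calls. After densifying, the values should agree, which is why the option is mainly of interest when the input is large and few elements are gathered, as in the motivating example from the commit message.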