docs: fix formatting for embedding_bag (#54666)

Summary:
fixes https://github.com/pytorch/pytorch/issues/43499

Pull Request resolved: https://github.com/pytorch/pytorch/pull/54666

Reviewed By: H-Huang

Differential Revision: D27411027

Pulled By: jbschlosser

fbshipit-source-id: a84cc174155bd725e108d8f953a21bb8de8d9d23
This commit is contained in:
Jeff Yang
2021-04-07 06:30:50 -07:00
committed by Facebook GitHub Bot
parent 6fd20a8dea
commit 263d8ef4ef
2 changed files with 43 additions and 49 deletions

View File

@ -231,7 +231,7 @@ class EmbeddingBag(Module):
EmbeddingBag also supports per-sample weights as an argument to the forward
pass. This scales the output of the Embedding before performing a weighted
reduction as specified by ``mode``. If :attr:`per_sample_weights`` is passed, the
reduction as specified by ``mode``. If :attr:`per_sample_weights` is passed, the
only supported ``mode`` is ``"sum"``, which computes a weighted sum according to
:attr:`per_sample_weights`.
@ -259,34 +259,6 @@ class EmbeddingBag(Module):
weight (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)`
initialized from :math:`\mathcal{N}(0, 1)`.
Inputs: :attr:`input` (IntTensor or LongTensor), :attr:`offsets` (IntTensor or LongTensor, optional), and
:attr:`per_index_weights` (Tensor, optional)
- :attr:`input` and :attr:`offsets` have to be of the same type, either int or long
- If :attr:`input` is 2D of shape `(B, N)`,
it will be treated as ``B`` bags (sequences) each of fixed length ``N``, and
this will return ``B`` values aggregated in a way depending on the :attr:`mode`.
:attr:`offsets` is ignored and required to be ``None`` in this case.
- If :attr:`input` is 1D of shape `(N)`,
it will be treated as a concatenation of multiple bags (sequences).
:attr:`offsets` is required to be a 1D tensor containing the
starting index positions of each bag in :attr:`input`. Therefore,
for :attr:`offsets` of shape `(B)`, :attr:`input` will be viewed as
having ``B`` bags. Empty bags (i.e., having 0-length) will have
returned vectors filled by zeros.
per_sample_weights (Tensor, optional): a tensor of float / double weights, or None
to indicate all weights should be taken to be ``1``. If specified, :attr:`per_sample_weights`
must have exactly the same shape as input and is treated as having the same
:attr:`offsets`, if those are not ``None``. Only supported for ``mode='sum'``.
Output shape: `(B, embedding_dim)`
Examples::
>>> # an Embedding module containing 10 tensors of size 3
@ -336,6 +308,36 @@ class EmbeddingBag(Module):
init.normal_(self.weight)
def forward(self, input: Tensor, offsets: Optional[Tensor] = None, per_sample_weights: Optional[Tensor] = None) -> Tensor:
"""Forward pass of EmbeddingBag.
Args:
input (Tensor): Tensor containing bags of indices into the embedding matrix.
offsets (Tensor, optional): Only used when :attr:`input` is 1D. :attr:`offsets` determines
the starting index position of each bag (sequence) in :attr:`input`.
per_sample_weights (Tensor, optional): a tensor of float / double weights, or None
to indicate all weights should be taken to be ``1``. If specified, :attr:`per_sample_weights`
must have exactly the same shape as input and is treated as having the same
:attr:`offsets`, if those are not ``None``. Only supported for ``mode='sum'``.
Returns:
Tensor output shape of `(B, embedding_dim)`.
.. note::
A few notes about ``input`` and ``offsets``:
- :attr:`input` and :attr:`offsets` have to be of the same type, either int or long
- If :attr:`input` is 2D of shape `(B, N)`, it will be treated as ``B`` bags (sequences)
each of fixed length ``N``, and this will return ``B`` values aggregated in a way
depending on the :attr:`mode`. :attr:`offsets` is ignored and required to be ``None`` in this case.
- If :attr:`input` is 1D of shape `(N)`, it will be treated as a concatenation of
multiple bags (sequences). :attr:`offsets` is required to be a 1D tensor containing the
starting index positions of each bag in :attr:`input`. Therefore, for :attr:`offsets` of shape `(B)`,
:attr:`input` will be viewed as having ``B`` bags. Empty bags (i.e., having 0-length) will have
returned vectors filled by zeros.
"""
return F.embedding_bag(input, self.weight, offsets,
self.max_norm, self.norm_type,
self.scale_grad_by_freq, self.mode, self.sparse,