Support modules that output scalar in Gather (and data parallel) (#7973)

* Support modules that output scalar in Gather (and data parallel)

* Improve warning msg
This commit is contained in:
Tongzhou Wang
2018-06-01 16:20:39 -04:00
committed by GitHub
parent 215abffe60
commit c6a923f486
3 changed files with 43 additions and 3 deletions

View File

@ -61,6 +61,12 @@ class DataParallel(Module):
that each such hook be executed before the corresponding
:meth:`~torch.nn.Module.forward` call of that device.
.. warning::
When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in
:func:`forward`, this wrapper will return a vector of length equal to
number of devices used in data parallelism, containing the result from
each device.
.. note::
There is a subtlety in using the
``pack sequence -> recurrent network -> unpack sequence`` pattern in a