Customized pin_memory for PackedSequence (#18079)

Summary:
fixes https://github.com/pytorch/pytorch/issues/18078
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18079

Reviewed By: ezyang

Differential Revision: D14521192

Pulled By: zou3519

fbshipit-source-id: cec773a3a6f2c405a0d9701e213b7caf81649181
Tongzhou Wang
2019-03-19 13:35:55 -07:00
committed by Facebook Github Bot
parent 916a670828
commit f212fd9fd6
3 changed files with 116 additions and 16 deletions
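
A minimal usage sketch of what this change enables, modeled on the new tests further down. The dataset, sizes, and the collate_to_packed helper are illustrative, and pinning only takes effect on a CUDA-capable build:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.nn.utils.rnn import pack_padded_sequence

    # Toy dataset: 10 fixed-length sequences of 5 time steps each.
    dataset = TensorDataset(torch.arange(10 * 5, dtype=torch.float32).view(10, 5))

    def collate_to_packed(batch):
        # Stack samples along dim 1 to get a T x B tensor, then pack it.
        data = torch.stack([sample[0] for sample in batch], 1)
        t, b = data.size()
        lengths = torch.randint(1, t, size=(b,), dtype=torch.int64)
        return pack_padded_sequence(data, lengths, enforce_sorted=False)

    # With this commit, pin_memory=True pins the PackedSequence's `data`,
    # `sorted_indices` and `unsorted_indices`; `batch_sizes` stays a CPU int64 tensor.
    loader = DataLoader(dataset, batch_size=2, collate_fn=collate_to_packed, pin_memory=True)
    if torch.cuda.is_available():
        for sample in loader:
            assert sample.is_pinned()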

aten/src/ATen/native/PackedSequence.cpp

@@ -8,6 +8,11 @@ void checkLongTensor(const Tensor& tensor) {
            "'lengths' argument should be a 1D CPU int64 tensor");
 }
 
+// This method returns `(data, batch_sizes)`, which are then passed into a
+// `PackedSequence` constructor.
+// `data` can be on arbitrary device and of arbitrary dtype, but `batch_sizes`
+// must be a CPU int64 tensor.
+// See NOTE [ device and dtype of a PackedSequence ]
 std::tuple<Tensor, Tensor> _pack_padded_sequence(const Tensor& _input, const Tensor& _lengths, bool batch_first) {
   auto input = batch_first ? _input.transpose(0, 1) : _input;
   auto lengths_t = _lengths.contiguous();
@@ -84,6 +89,9 @@ std::tuple<Tensor, Tensor> _pack_padded_sequence(const Tensor& _input, const Ten
   return std::make_tuple(at::cat(steps), batch_sizes_t);
 }
 
+// `grad` could be on arbitrary device and of arbitrary dtype, but `_batch_sizes`
+// is guaranteed to be a CPU int64 tensor.
+// See NOTE [ device and dtype of a PackedSequence ]
 Tensor _pack_padded_sequence_backward(const Tensor& grad, at::IntArrayRef input_size, const Tensor& _batch_sizes, bool batch_first) {
   std::vector<int64_t> input_size_after_t = input_size.vec();
   if (batch_first) {
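
The invariant these new comments point to can be observed from Python. A small illustrative check, not taken from the diff:

    import torch
    from torch.nn.utils.rnn import pack_padded_sequence

    padded = torch.randn(4, 3)                 # T x B, lengths sorted in decreasing order
    lengths = torch.tensor([4, 3, 2])
    packed = pack_padded_sequence(padded, lengths)

    # `data` follows the input's device and dtype; `batch_sizes` is always CPU int64.
    print(packed.data.dtype)                   # torch.float32
    print(packed.batch_sizes.dtype)            # torch.int64

    if torch.cuda.is_available():
        packed_cuda = packed.cuda()
        print(packed_cuda.data.device)         # cuda:0
        print(packed_cuda.batch_sizes.device)  # still cpu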

test/test_dataloader.py

@@ -1004,7 +1004,7 @@ class TestNamedTupleDataLoader(TestCase):
         self.assertIsInstance(batch.data, NamedTupleDataset.Data)
 
 
-class SimpleCustomBatch:
+class SimpleCustomBatch(object):
     def __init__(self, data):
         transposed_data = list(zip(*data))
         self.inp = torch.stack(transposed_data[0], 0)
@@ -1015,11 +1015,28 @@ class SimpleCustomBatch:
         self.tgt = self.tgt.pin_memory()
         return self
 
+    def is_pinned(self):
+        return self.inp.is_pinned() and self.tgt.is_pinned()
+
 
 def collate_wrapper(batch):
     return SimpleCustomBatch(batch)
 
 
+def collate_into_packed_sequence(batch):
+    data = torch.stack([sample[0] for sample in batch], 1)
+    t, b = data.size()
+    lengths = torch.randint(1, t, size=(b,), dtype=torch.int64)
+    return torch.nn.utils.rnn.pack_padded_sequence(data, lengths, enforce_sorted=False)
+
+
+def collate_into_packed_sequence_batch_first(batch):
+    data = torch.stack([sample[0] for sample in batch], 0)
+    b, t = data.size()
+    lengths = torch.randint(1, t, size=(b,), dtype=torch.int64)
+    return torch.nn.utils.rnn.pack_padded_sequence(data, lengths, batch_first=True, enforce_sorted=False)
+
+
 class TestCustomPinFn(TestCase):
     def setUp(self):
         inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
@@ -1029,20 +1046,32 @@ class TestCustomPinFn(TestCase):
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     @skipIfRocm
     def test_custom_batch_pin(self):
-        loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_wrapper,
-                            pin_memory=True)
-        for sample in loader:
-            self.assertTrue(sample.inp.is_pinned())
-            self.assertTrue(sample.tgt.is_pinned())
+        test_cases = [
+            (collate_wrapper, SimpleCustomBatch),
+            (collate_into_packed_sequence, torch.nn.utils.rnn.PackedSequence),
+            (collate_into_packed_sequence_batch_first, torch.nn.utils.rnn.PackedSequence),
+        ]
+        for collate_fn, elem_cls in test_cases:
+            loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_fn,
+                                pin_memory=True)
+            for sample in loader:
+                self.assertIsInstance(sample, elem_cls)
+                self.assertTrue(sample.is_pinned())
 
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     @skipIfRocm
     def test_custom_batch_pin_worker(self):
-        loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_wrapper,
-                            pin_memory=True, num_workers=1)
-        for sample in loader:
-            self.assertTrue(sample.inp.is_pinned())
-            self.assertTrue(sample.tgt.is_pinned())
+        test_cases = [
+            (collate_wrapper, SimpleCustomBatch),
+            (collate_into_packed_sequence, torch.nn.utils.rnn.PackedSequence),
+            (collate_into_packed_sequence_batch_first, torch.nn.utils.rnn.PackedSequence),
+        ]
+        for collate_fn, elem_cls in test_cases:
+            loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_fn,
+                                pin_memory=True, num_workers=1)
+            for sample in loader:
+                self.assertIsInstance(sample, elem_cls)
+                self.assertTrue(sample.is_pinned())
 
 
 class TestWorkerQueueDataset(Dataset):
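
For background on why the tests assert is_pinned(): pinned (page-locked) host memory is what allows asynchronous host-to-device copies via non_blocking=True, and the new PackedSequence.cuda forwards such keyword arguments to data.cuda. A rough sketch of that payoff with an illustrative plain-tensor dataset, assuming a CUDA device:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.randn(8, 5), torch.randn(8, 5))
    loader = DataLoader(dataset, batch_size=2, pin_memory=True)

    if torch.cuda.is_available():
        for inp, tgt in loader:
            # Copies out of pinned memory can overlap with GPU compute.
            inp = inp.cuda(non_blocking=True)
            tgt = tgt.cuda(non_blocking=True)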

torch/nn/utils/rnn.py

@@ -33,8 +33,29 @@ class PackedSequence(PackedSequence_):
         data (Tensor): Tensor containing packed sequence
         batch_sizes (Tensor): Tensor of integers holding
             information about the batch size at each sequence step
+        sorted_indices (Tensor, optional): Tensor of integers holding how this
+            :class:`PackedSequence` is constructed from sequences.
+        unsorted_indices (Tensor, optional): Tensor of integers holding how this
+            to recover the original sequences with correct order.
+
+    .. note::
+        :attr:`data` can be on arbitrary device and of arbitrary dtype.
+        :attr:`sorted_indices` and :attr:`unsorted_indices` must be ``torch.int64``
+        tensors on the same device as :attr:`data`.
+
+        However, :attr:`batch_sizes` should always be a CPU ``torch.int64`` tensor.
+
+        This invariant is maintained throughout :class:`PackedSequence` class,
+        and all functions that construct a `:class:PackedSequence` in PyTorch
+        (i.e., they only pass in tensors conforming to this constraint).
+
     """
 
+    # NOTE [ device and dtype of a PackedSequence ]
+    #
+    # See the note above in doc string (starting with ":attr:`data` can be on
+    # arbitrary device...").
+
     def __new__(cls, data, batch_sizes=None, sorted_indices=None, unsorted_indices=None):
         # PackedSequence used to only have __init__(self, data, batch_sizes)
         # without a __new__ like this. So to preserve BC for calling in keyword
@@ -58,11 +79,20 @@ class PackedSequence(PackedSequence_):
             return super(PackedSequence, cls).__new__(
                 cls, data[0], data[1], sorted_indices)
 
+    def pin_memory(self):
+        # Why not convert `batch_sizes`?
+        # See NOTE [ device and dtype of a PackedSequence ]
+        return type(self)(self.data.pin_memory(), self.batch_sizes,
+                          bind(self.sorted_indices, lambda t: t.pin_memory()),
+                          bind(self.unsorted_indices, lambda t: t.pin_memory()))
+
     def cuda(self, *args, **kwargs):
         """Returns a GPU copy if `self.data` not already on the GPU"""
         if self.is_cuda:
             return self
         else:
+            # Why not convert `batch_sizes`?
+            # See NOTE [ device and dtype of a PackedSequence ]
             return type(self)(self.data.cuda(*args, **kwargs), self.batch_sizes,
                               bind(self.sorted_indices, lambda t: t.cuda(*args, **kwargs)),
                               bind(self.unsorted_indices, lambda t: t.cuda(*args, **kwargs)))
@@ -70,6 +100,8 @@ class PackedSequence(PackedSequence_):
     def cpu(self):
         """Returns a CPU copy if `self.data` not already on the CPU"""
         if self.is_cuda:
+            # Why not convert `batch_sizes`?
+            # See NOTE [ device and dtype of a PackedSequence ]
             return type(self)(self.data.cpu(), self.batch_sizes,
                               bind(self.sorted_indices, lambda t: t.cpu()),
                               bind(self.unsorted_indices, lambda t: t.cpu()))
@@ -78,41 +110,65 @@ class PackedSequence(PackedSequence_):
     def double(self):
         r"""Returns copy with `self.data` cast to double type"""
+
+        # Why not convert `batch_sizes`?
+        # See NOTE [ device and dtype of a PackedSequence ]
         return type(self)(self.data.double(), self.batch_sizes,
                           self.sorted_indices, self.unsorted_indices)
 
     def float(self):
         r"""Returns copy with `self.data` cast to float type"""
+
+        # Why not convert `batch_sizes`?
+        # See NOTE [ device and dtype of a PackedSequence ]
         return type(self)(self.data.float(), self.batch_sizes,
                           self.sorted_indices, self.unsorted_indices)
 
     def half(self):
         r"""Returns copy with `self.data` cast to half type"""
+
+        # Why not convert `batch_sizes`?
+        # See NOTE [ device and dtype of a PackedSequence ]
         return type(self)(self.data.half(), self.batch_sizes,
                           self.sorted_indices, self.unsorted_indices)
 
     def long(self):
         r"""Returns copy with `self.data` cast to long type"""
+
+        # Why not convert `batch_sizes`?
+        # See NOTE [ device and dtype of a PackedSequence ]
         return type(self)(self.data.long(), self.batch_sizes,
                           self.sorted_indices, self.unsorted_indices)
 
     def int(self):
         r"""Returns copy with `self.data` cast to int type"""
+
+        # Why not convert `batch_sizes`?
+        # See NOTE [ device and dtype of a PackedSequence ]
         return type(self)(self.data.int(), self.batch_sizes,
                           self.sorted_indices, self.unsorted_indices)
 
     def short(self):
         r"""Returns copy with `self.data` cast to short type"""
+
+        # Why not convert `batch_sizes`?
+        # See NOTE [ device and dtype of a PackedSequence ]
         return type(self)(self.data.short(), self.batch_sizes,
                           self.sorted_indices, self.unsorted_indices)
 
     def char(self):
         r"""Returns copy with `self.data` cast to char type"""
+
+        # Why not convert `batch_sizes`?
+        # See NOTE [ device and dtype of a PackedSequence ]
         return type(self)(self.data.char(), self.batch_sizes,
                           self.sorted_indices, self.unsorted_indices)
 
     def byte(self):
         r"""Returns copy with `self.data` cast to byte type"""
+
+        # Why not convert `batch_sizes`?
+        # See NOTE [ device and dtype of a PackedSequence ]
         return type(self)(self.data.byte(), self.batch_sizes,
                           self.sorted_indices, self.unsorted_indices)
 
@@ -127,6 +183,9 @@ class PackedSequence(PackedSequence_):
         and :class:`torch.device`, then ``self`` is returned.
         Otherwise, returns a copy with the desired configuration.
         """
+
+        # Why not convert `batch_sizes`?
+        # See NOTE [ device and dtype of a PackedSequence ]
         data = self.data.to(*args, **kwargs)
         sorted_indices = self.sorted_indices
         unsorted_indices = self.unsorted_indices
@@ -145,6 +204,10 @@ class PackedSequence(PackedSequence_):
         r"""Returns true if `self.data` stored on a gpu"""
         return self.data.is_cuda
 
+    def is_pinned(self):
+        r"""Returns true if `self.data` stored on in pinned memory"""
+        return self.data.is_pinned()
+
 
 def invert_permutation(permutation):
     if permutation is None:
@@ -158,12 +221,12 @@ def invert_permutation(permutation):
 def pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True):
     r"""Packs a Tensor containing padded sequences of variable length.
 
-    Input can be of size ``T x B x *`` where `T` is the length of the longest sequence
-    (equal to ``lengths[0]``), `B` is the batch size, and `*` is any number of
-    dimensions (including 0). If ``batch_first`` is True ``B x T x *`` inputs are
-    expected.
+    :attr:`input` can be of size ``T x B x *`` where `T` is the length of the
+    longest sequence (equal to ``lengths[0]``), ``B`` is the batch size, and
+    ``*`` is any number of dimensions (including 0). If ``batch_first`` is
+    ``True``, ``B x T x *`` :attr:`input` is expected.
 
-    For unsorted sequences, use `enforce_sorted = False`. If ``enforce_sorted`` is
+    For unsorted sequences, use `enforce_sorted = False`. If :attr:`enforce_sorted` is
     ``True``, the sequences should be sorted by length in a decreasing order, i.e.
     ``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the shortest
     one. `enforce_sorted = True` is only necessary for ONNX export.
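
A short example of the calling convention the updated docstring describes, together with the new pinning helpers. This is illustrative rather than part of the diff, and the pinning branch assumes a CUDA build:

    import torch
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

    padded = torch.zeros(5, 3, 2)              # T x B x *: three sequences padded to length 5
    lengths = torch.tensor([5, 3, 2])          # decreasing, so the default enforce_sorted=True is fine
    packed = pack_padded_sequence(padded, lengths)

    unpacked, out_lengths = pad_packed_sequence(packed)
    assert out_lengths.tolist() == [5, 3, 2]

    # Unsorted lengths need enforce_sorted=False (not usable for ONNX export).
    packed_unsorted = pack_padded_sequence(padded, torch.tensor([2, 5, 3]), enforce_sorted=False)

    if torch.cuda.is_available():
        pinned = packed_unsorted.pin_memory()      # pins data and the index tensors
        assert pinned.is_pinned()
        assert not pinned.batch_sizes.is_pinned()  # batch_sizes is left as a CPU int64 tensor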