mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-01 22:14:53 +08:00
Customized pin_memory for PackedSequence (#18079)
Summary: fixes https://github.com/pytorch/pytorch/issues/18078 Pull Request resolved: https://github.com/pytorch/pytorch/pull/18079 Reviewed By: ezyang Differential Revision: D14521192 Pulled By: zou3519 fbshipit-source-id: cec773a3a6f2c405a0d9701e213b7caf81649181
This commit is contained in:
committed by
Facebook Github Bot
parent
916a670828
commit
f212fd9fd6
@ -8,6 +8,11 @@ void checkLongTensor(const Tensor& tensor) {
|
|||||||
"'lengths' argument should be a 1D CPU int64 tensor");
|
"'lengths' argument should be a 1D CPU int64 tensor");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This method returns `(data, batch_sizes)`, which are then passed into a
|
||||||
|
// `PackedSequence` constructor.
|
||||||
|
// `data` can be on arbitrary device and of arbitrary dtype, but `batch_sizes`
|
||||||
|
// must be a CPU int64 tensor.
|
||||||
|
// See NOTE [ device and dtype of a PackedSequence ]
|
||||||
std::tuple<Tensor, Tensor> _pack_padded_sequence(const Tensor& _input, const Tensor& _lengths, bool batch_first) {
|
std::tuple<Tensor, Tensor> _pack_padded_sequence(const Tensor& _input, const Tensor& _lengths, bool batch_first) {
|
||||||
auto input = batch_first ? _input.transpose(0, 1) : _input;
|
auto input = batch_first ? _input.transpose(0, 1) : _input;
|
||||||
auto lengths_t = _lengths.contiguous();
|
auto lengths_t = _lengths.contiguous();
|
||||||
@ -84,6 +89,9 @@ std::tuple<Tensor, Tensor> _pack_padded_sequence(const Tensor& _input, const Ten
|
|||||||
return std::make_tuple(at::cat(steps), batch_sizes_t);
|
return std::make_tuple(at::cat(steps), batch_sizes_t);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// `grad` could be on arbitrary device and of arbitrary dtype, but `_batch_sizes`
|
||||||
|
// is guaranteed to be a CPU int64 tensor.
|
||||||
|
// See NOTE [ device and dtype of a PackedSequence ]
|
||||||
Tensor _pack_padded_sequence_backward(const Tensor& grad, at::IntArrayRef input_size, const Tensor& _batch_sizes, bool batch_first) {
|
Tensor _pack_padded_sequence_backward(const Tensor& grad, at::IntArrayRef input_size, const Tensor& _batch_sizes, bool batch_first) {
|
||||||
std::vector<int64_t> input_size_after_t = input_size.vec();
|
std::vector<int64_t> input_size_after_t = input_size.vec();
|
||||||
if (batch_first) {
|
if (batch_first) {
|
||||||
|
|||||||
@ -1004,7 +1004,7 @@ class TestNamedTupleDataLoader(TestCase):
|
|||||||
self.assertIsInstance(batch.data, NamedTupleDataset.Data)
|
self.assertIsInstance(batch.data, NamedTupleDataset.Data)
|
||||||
|
|
||||||
|
|
||||||
class SimpleCustomBatch:
|
class SimpleCustomBatch(object):
|
||||||
def __init__(self, data):
|
def __init__(self, data):
|
||||||
transposed_data = list(zip(*data))
|
transposed_data = list(zip(*data))
|
||||||
self.inp = torch.stack(transposed_data[0], 0)
|
self.inp = torch.stack(transposed_data[0], 0)
|
||||||
@ -1015,11 +1015,28 @@ class SimpleCustomBatch:
|
|||||||
self.tgt = self.tgt.pin_memory()
|
self.tgt = self.tgt.pin_memory()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def is_pinned(self):
|
||||||
|
return self.inp.is_pinned() and self.tgt.is_pinned()
|
||||||
|
|
||||||
|
|
||||||
def collate_wrapper(batch):
|
def collate_wrapper(batch):
|
||||||
return SimpleCustomBatch(batch)
|
return SimpleCustomBatch(batch)
|
||||||
|
|
||||||
|
|
||||||
|
def collate_into_packed_sequence(batch):
|
||||||
|
data = torch.stack([sample[0] for sample in batch], 1)
|
||||||
|
t, b = data.size()
|
||||||
|
lengths = torch.randint(1, t, size=(b,), dtype=torch.int64)
|
||||||
|
return torch.nn.utils.rnn.pack_padded_sequence(data, lengths, enforce_sorted=False)
|
||||||
|
|
||||||
|
|
||||||
|
def collate_into_packed_sequence_batch_first(batch):
|
||||||
|
data = torch.stack([sample[0] for sample in batch], 0)
|
||||||
|
b, t = data.size()
|
||||||
|
lengths = torch.randint(1, t, size=(b,), dtype=torch.int64)
|
||||||
|
return torch.nn.utils.rnn.pack_padded_sequence(data, lengths, batch_first=True, enforce_sorted=False)
|
||||||
|
|
||||||
|
|
||||||
class TestCustomPinFn(TestCase):
|
class TestCustomPinFn(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
|
inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
|
||||||
@ -1029,20 +1046,32 @@ class TestCustomPinFn(TestCase):
|
|||||||
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
|
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
|
||||||
@skipIfRocm
|
@skipIfRocm
|
||||||
def test_custom_batch_pin(self):
|
def test_custom_batch_pin(self):
|
||||||
loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_wrapper,
|
test_cases = [
|
||||||
pin_memory=True)
|
(collate_wrapper, SimpleCustomBatch),
|
||||||
for sample in loader:
|
(collate_into_packed_sequence, torch.nn.utils.rnn.PackedSequence),
|
||||||
self.assertTrue(sample.inp.is_pinned())
|
(collate_into_packed_sequence_batch_first, torch.nn.utils.rnn.PackedSequence),
|
||||||
self.assertTrue(sample.tgt.is_pinned())
|
]
|
||||||
|
for collate_fn, elem_cls in test_cases:
|
||||||
|
loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_fn,
|
||||||
|
pin_memory=True)
|
||||||
|
for sample in loader:
|
||||||
|
self.assertIsInstance(sample, elem_cls)
|
||||||
|
self.assertTrue(sample.is_pinned())
|
||||||
|
|
||||||
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
|
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
|
||||||
@skipIfRocm
|
@skipIfRocm
|
||||||
def test_custom_batch_pin_worker(self):
|
def test_custom_batch_pin_worker(self):
|
||||||
loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_wrapper,
|
test_cases = [
|
||||||
pin_memory=True, num_workers=1)
|
(collate_wrapper, SimpleCustomBatch),
|
||||||
for sample in loader:
|
(collate_into_packed_sequence, torch.nn.utils.rnn.PackedSequence),
|
||||||
self.assertTrue(sample.inp.is_pinned())
|
(collate_into_packed_sequence_batch_first, torch.nn.utils.rnn.PackedSequence),
|
||||||
self.assertTrue(sample.tgt.is_pinned())
|
]
|
||||||
|
for collate_fn, elem_cls in test_cases:
|
||||||
|
loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_fn,
|
||||||
|
pin_memory=True, num_workers=1)
|
||||||
|
for sample in loader:
|
||||||
|
self.assertIsInstance(sample, elem_cls)
|
||||||
|
self.assertTrue(sample.is_pinned())
|
||||||
|
|
||||||
|
|
||||||
class TestWorkerQueueDataset(Dataset):
|
class TestWorkerQueueDataset(Dataset):
|
||||||
|
|||||||
@ -33,8 +33,29 @@ class PackedSequence(PackedSequence_):
|
|||||||
data (Tensor): Tensor containing packed sequence
|
data (Tensor): Tensor containing packed sequence
|
||||||
batch_sizes (Tensor): Tensor of integers holding
|
batch_sizes (Tensor): Tensor of integers holding
|
||||||
information about the batch size at each sequence step
|
information about the batch size at each sequence step
|
||||||
|
sorted_indices (Tensor, optional): Tensor of integers holding how this
|
||||||
|
:class:`PackedSequence` is constructed from sequences.
|
||||||
|
unsorted_indices (Tensor, optional): Tensor of integers holding how this
|
||||||
|
to recover the original sequences with correct order.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
:attr:`data` can be on arbitrary device and of arbitrary dtype.
|
||||||
|
:attr:`sorted_indices` and :attr:`unsorted_indices` must be ``torch.int64``
|
||||||
|
tensors on the same device as :attr:`data`.
|
||||||
|
|
||||||
|
However, :attr:`batch_sizes` should always be a CPU ``torch.int64`` tensor.
|
||||||
|
|
||||||
|
This invariant is maintained throughout :class:`PackedSequence` class,
|
||||||
|
and all functions that construct a `:class:PackedSequence` in PyTorch
|
||||||
|
(i.e., they only pass in tensors conforming to this constraint).
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# NOTE [ device and dtype of a PackedSequence ]
|
||||||
|
#
|
||||||
|
# See the note above in doc string (starting with ":attr:`data` can be on
|
||||||
|
# arbitrary device...").
|
||||||
|
|
||||||
def __new__(cls, data, batch_sizes=None, sorted_indices=None, unsorted_indices=None):
|
def __new__(cls, data, batch_sizes=None, sorted_indices=None, unsorted_indices=None):
|
||||||
# PackedSequence used to only have __init__(self, data, batch_sizes)
|
# PackedSequence used to only have __init__(self, data, batch_sizes)
|
||||||
# without a __new__ like this. So to preserve BC for calling in keyword
|
# without a __new__ like this. So to preserve BC for calling in keyword
|
||||||
@ -58,11 +79,20 @@ class PackedSequence(PackedSequence_):
|
|||||||
return super(PackedSequence, cls).__new__(
|
return super(PackedSequence, cls).__new__(
|
||||||
cls, data[0], data[1], sorted_indices)
|
cls, data[0], data[1], sorted_indices)
|
||||||
|
|
||||||
|
def pin_memory(self):
|
||||||
|
# Why not convert `batch_sizes`?
|
||||||
|
# See NOTE [ device and dtype of a PackedSequence ]
|
||||||
|
return type(self)(self.data.pin_memory(), self.batch_sizes,
|
||||||
|
bind(self.sorted_indices, lambda t: t.pin_memory()),
|
||||||
|
bind(self.unsorted_indices, lambda t: t.pin_memory()))
|
||||||
|
|
||||||
def cuda(self, *args, **kwargs):
|
def cuda(self, *args, **kwargs):
|
||||||
"""Returns a GPU copy if `self.data` not already on the GPU"""
|
"""Returns a GPU copy if `self.data` not already on the GPU"""
|
||||||
if self.is_cuda:
|
if self.is_cuda:
|
||||||
return self
|
return self
|
||||||
else:
|
else:
|
||||||
|
# Why not convert `batch_sizes`?
|
||||||
|
# See NOTE [ device and dtype of a PackedSequence ]
|
||||||
return type(self)(self.data.cuda(*args, **kwargs), self.batch_sizes,
|
return type(self)(self.data.cuda(*args, **kwargs), self.batch_sizes,
|
||||||
bind(self.sorted_indices, lambda t: t.cuda(*args, **kwargs)),
|
bind(self.sorted_indices, lambda t: t.cuda(*args, **kwargs)),
|
||||||
bind(self.unsorted_indices, lambda t: t.cuda(*args, **kwargs)))
|
bind(self.unsorted_indices, lambda t: t.cuda(*args, **kwargs)))
|
||||||
@ -70,6 +100,8 @@ class PackedSequence(PackedSequence_):
|
|||||||
def cpu(self):
|
def cpu(self):
|
||||||
"""Returns a CPU copy if `self.data` not already on the CPU"""
|
"""Returns a CPU copy if `self.data` not already on the CPU"""
|
||||||
if self.is_cuda:
|
if self.is_cuda:
|
||||||
|
# Why not convert `batch_sizes`?
|
||||||
|
# See NOTE [ device and dtype of a PackedSequence ]
|
||||||
return type(self)(self.data.cpu(), self.batch_sizes,
|
return type(self)(self.data.cpu(), self.batch_sizes,
|
||||||
bind(self.sorted_indices, lambda t: t.cpu()),
|
bind(self.sorted_indices, lambda t: t.cpu()),
|
||||||
bind(self.unsorted_indices, lambda t: t.cpu()))
|
bind(self.unsorted_indices, lambda t: t.cpu()))
|
||||||
@ -78,41 +110,65 @@ class PackedSequence(PackedSequence_):
|
|||||||
|
|
||||||
def double(self):
|
def double(self):
|
||||||
r"""Returns copy with `self.data` cast to double type"""
|
r"""Returns copy with `self.data` cast to double type"""
|
||||||
|
|
||||||
|
# Why not convert `batch_sizes`?
|
||||||
|
# See NOTE [ device and dtype of a PackedSequence ]
|
||||||
return type(self)(self.data.double(), self.batch_sizes,
|
return type(self)(self.data.double(), self.batch_sizes,
|
||||||
self.sorted_indices, self.unsorted_indices)
|
self.sorted_indices, self.unsorted_indices)
|
||||||
|
|
||||||
def float(self):
|
def float(self):
|
||||||
r"""Returns copy with `self.data` cast to float type"""
|
r"""Returns copy with `self.data` cast to float type"""
|
||||||
|
|
||||||
|
# Why not convert `batch_sizes`?
|
||||||
|
# See NOTE [ device and dtype of a PackedSequence ]
|
||||||
return type(self)(self.data.float(), self.batch_sizes,
|
return type(self)(self.data.float(), self.batch_sizes,
|
||||||
self.sorted_indices, self.unsorted_indices)
|
self.sorted_indices, self.unsorted_indices)
|
||||||
|
|
||||||
def half(self):
|
def half(self):
|
||||||
r"""Returns copy with `self.data` cast to half type"""
|
r"""Returns copy with `self.data` cast to half type"""
|
||||||
|
|
||||||
|
# Why not convert `batch_sizes`?
|
||||||
|
# See NOTE [ device and dtype of a PackedSequence ]
|
||||||
return type(self)(self.data.half(), self.batch_sizes,
|
return type(self)(self.data.half(), self.batch_sizes,
|
||||||
self.sorted_indices, self.unsorted_indices)
|
self.sorted_indices, self.unsorted_indices)
|
||||||
|
|
||||||
def long(self):
|
def long(self):
|
||||||
r"""Returns copy with `self.data` cast to long type"""
|
r"""Returns copy with `self.data` cast to long type"""
|
||||||
|
|
||||||
|
# Why not convert `batch_sizes`?
|
||||||
|
# See NOTE [ device and dtype of a PackedSequence ]
|
||||||
return type(self)(self.data.long(), self.batch_sizes,
|
return type(self)(self.data.long(), self.batch_sizes,
|
||||||
self.sorted_indices, self.unsorted_indices)
|
self.sorted_indices, self.unsorted_indices)
|
||||||
|
|
||||||
def int(self):
|
def int(self):
|
||||||
r"""Returns copy with `self.data` cast to int type"""
|
r"""Returns copy with `self.data` cast to int type"""
|
||||||
|
|
||||||
|
# Why not convert `batch_sizes`?
|
||||||
|
# See NOTE [ device and dtype of a PackedSequence ]
|
||||||
return type(self)(self.data.int(), self.batch_sizes,
|
return type(self)(self.data.int(), self.batch_sizes,
|
||||||
self.sorted_indices, self.unsorted_indices)
|
self.sorted_indices, self.unsorted_indices)
|
||||||
|
|
||||||
def short(self):
|
def short(self):
|
||||||
r"""Returns copy with `self.data` cast to short type"""
|
r"""Returns copy with `self.data` cast to short type"""
|
||||||
|
|
||||||
|
# Why not convert `batch_sizes`?
|
||||||
|
# See NOTE [ device and dtype of a PackedSequence ]
|
||||||
return type(self)(self.data.short(), self.batch_sizes,
|
return type(self)(self.data.short(), self.batch_sizes,
|
||||||
self.sorted_indices, self.unsorted_indices)
|
self.sorted_indices, self.unsorted_indices)
|
||||||
|
|
||||||
def char(self):
|
def char(self):
|
||||||
r"""Returns copy with `self.data` cast to char type"""
|
r"""Returns copy with `self.data` cast to char type"""
|
||||||
|
|
||||||
|
# Why not convert `batch_sizes`?
|
||||||
|
# See NOTE [ device and dtype of a PackedSequence ]
|
||||||
return type(self)(self.data.char(), self.batch_sizes,
|
return type(self)(self.data.char(), self.batch_sizes,
|
||||||
self.sorted_indices, self.unsorted_indices)
|
self.sorted_indices, self.unsorted_indices)
|
||||||
|
|
||||||
def byte(self):
|
def byte(self):
|
||||||
r"""Returns copy with `self.data` cast to byte type"""
|
r"""Returns copy with `self.data` cast to byte type"""
|
||||||
|
|
||||||
|
# Why not convert `batch_sizes`?
|
||||||
|
# See NOTE [ device and dtype of a PackedSequence ]
|
||||||
return type(self)(self.data.byte(), self.batch_sizes,
|
return type(self)(self.data.byte(), self.batch_sizes,
|
||||||
self.sorted_indices, self.unsorted_indices)
|
self.sorted_indices, self.unsorted_indices)
|
||||||
|
|
||||||
@ -127,6 +183,9 @@ class PackedSequence(PackedSequence_):
|
|||||||
and :class:`torch.device`, then ``self`` is returned.
|
and :class:`torch.device`, then ``self`` is returned.
|
||||||
Otherwise, returns a copy with the desired configuration.
|
Otherwise, returns a copy with the desired configuration.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Why not convert `batch_sizes`?
|
||||||
|
# See NOTE [ device and dtype of a PackedSequence ]
|
||||||
data = self.data.to(*args, **kwargs)
|
data = self.data.to(*args, **kwargs)
|
||||||
sorted_indices = self.sorted_indices
|
sorted_indices = self.sorted_indices
|
||||||
unsorted_indices = self.unsorted_indices
|
unsorted_indices = self.unsorted_indices
|
||||||
@ -145,6 +204,10 @@ class PackedSequence(PackedSequence_):
|
|||||||
r"""Returns true if `self.data` stored on a gpu"""
|
r"""Returns true if `self.data` stored on a gpu"""
|
||||||
return self.data.is_cuda
|
return self.data.is_cuda
|
||||||
|
|
||||||
|
def is_pinned(self):
|
||||||
|
r"""Returns true if `self.data` stored on in pinned memory"""
|
||||||
|
return self.data.is_pinned()
|
||||||
|
|
||||||
|
|
||||||
def invert_permutation(permutation):
|
def invert_permutation(permutation):
|
||||||
if permutation is None:
|
if permutation is None:
|
||||||
@ -158,12 +221,12 @@ def invert_permutation(permutation):
|
|||||||
def pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True):
|
def pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True):
|
||||||
r"""Packs a Tensor containing padded sequences of variable length.
|
r"""Packs a Tensor containing padded sequences of variable length.
|
||||||
|
|
||||||
Input can be of size ``T x B x *`` where `T` is the length of the longest sequence
|
:attr:`input` can be of size ``T x B x *`` where `T` is the length of the
|
||||||
(equal to ``lengths[0]``), `B` is the batch size, and `*` is any number of
|
longest sequence (equal to ``lengths[0]``), ``B`` is the batch size, and
|
||||||
dimensions (including 0). If ``batch_first`` is True ``B x T x *`` inputs are
|
``*`` is any number of dimensions (including 0). If ``batch_first`` is
|
||||||
expected.
|
``True``, ``B x T x *`` :attr:`input` is expected.
|
||||||
|
|
||||||
For unsorted sequences, use `enforce_sorted = False`. If ``enforce_sorted`` is
|
For unsorted sequences, use `enforce_sorted = False`. If :attr:`enforce_sorted` is
|
||||||
``True``, the sequences should be sorted by length in a decreasing order, i.e.
|
``True``, the sequences should be sorted by length in a decreasing order, i.e.
|
||||||
``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the shortest
|
``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the shortest
|
||||||
one. `enforce_sorted = True` is only necessary for ONNX export.
|
one. `enforce_sorted = True` is only necessary for ONNX export.
|
||||||
|
|||||||
Reference in New Issue
Block a user