mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Update pin memory related APIs to not pass 'device' argument (#131858)
Based on https://github.com/pytorch/pytorch/pull/126376, this PR tries to update all PT callers (e.g., `Tensor.is_pinned()`, `Tensor.pin_memory()`) to not pass `device` argument. As for `storage/untyped_storage.is_pinned()/pin_memory()`, we keep the `device` argument but passing `device` is discouraged. And if not given, the default `device` is still 'cuda' for BC. Additionally, based on device-agnostic pin_memory, `pin_memory_device` argument of `torch.utils.data.DataLoader` is discouraged now. For BC, explictly passing this argument is still effective. If not given, the default `device` will be the current accelerator. Fixes #124908 Relates https://github.com/pytorch/pytorch/pull/126376 Pull Request resolved: https://github.com/pytorch/pytorch/pull/131858 Approved by: https://github.com/albanD Co-authored-by: albanD <desmaison.alban@gmail.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
0dca756832
commit
c07dc64017
@ -3092,17 +3092,15 @@ class TestDictDataLoader(TestCase):
|
||||
self.dataset, batch_size=2, pin_memory=True, pin_memory_device="cuda"
|
||||
)
|
||||
for sample in loader:
|
||||
self.assertTrue(sample["a_tensor"].is_pinned(device="cuda"))
|
||||
self.assertTrue(sample["another_dict"]["a_number"].is_pinned(device="cuda"))
|
||||
self.assertTrue(sample["a_tensor"].is_pinned())
|
||||
self.assertTrue(sample["another_dict"]["a_number"].is_pinned())
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
|
||||
def test_pin_memory_with_only_device(self):
|
||||
loader = DataLoader(self.dataset, batch_size=2, pin_memory_device="cuda")
|
||||
for sample in loader:
|
||||
self.assertFalse(sample["a_tensor"].is_pinned(device="cuda"))
|
||||
self.assertFalse(
|
||||
sample["another_dict"]["a_number"].is_pinned(device="cuda")
|
||||
)
|
||||
self.assertFalse(sample["a_tensor"].is_pinned())
|
||||
self.assertFalse(sample["another_dict"]["a_number"].is_pinned())
|
||||
|
||||
|
||||
class DummyDataset(torch.utils.data.Dataset):
|
||||
|
Reference in New Issue
Block a user