Update pin memory related APIs to not pass 'device' argument (#131858)

Based on https://github.com/pytorch/pytorch/pull/126376, this PR updates all PyTorch callers (e.g., `Tensor.is_pinned()`, `Tensor.pin_memory()`) so that they no longer pass the `device` argument.
As for `storage`/`untyped_storage` `is_pinned()`/`pin_memory()`, we keep the `device` argument, but passing it is discouraged; if it is not given, the default `device` remains `'cuda'` for BC.
Additionally, with device-agnostic pin_memory in place, the `pin_memory_device` argument of `torch.utils.data.DataLoader` is now discouraged. For BC, explicitly passing this argument still takes effect; if it is not given, the default `device` is the current accelerator.
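A minimal sketch of the updated call pattern (assuming `torch` is installed; actual pinning only takes effect when an accelerator such as CUDA is available, otherwise the pin request is a no-op or raises):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Tensor-level API: no `device` argument is needed anymore.
t = torch.randn(4)
if torch.cuda.is_available():
    pinned = t.pin_memory()      # pins on the current accelerator
    assert pinned.is_pinned()
else:
    assert not t.is_pinned()     # a plain CPU tensor is not pinned

# DataLoader: `pin_memory=True` alone is enough; `pin_memory_device`
# is kept only for BC and defaults to the current accelerator.
loader = DataLoader(TensorDataset(torch.arange(8.0)),
                    batch_size=2, pin_memory=True)
batch = next(iter(loader))[0]
assert batch.shape == (2,)
```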

Fixes #124908
Relates https://github.com/pytorch/pytorch/pull/126376

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131858
Approved by: https://github.com/albanD

Co-authored-by: albanD <desmaison.alban@gmail.com>
This commit is contained in:
wizzniu
2025-01-15 17:23:33 +00:00
committed by PyTorch MergeBot
parent 0dca756832
commit c07dc64017
9 changed files with 61 additions and 31 deletions


@@ -3092,17 +3092,15 @@ class TestDictDataLoader(TestCase):
             self.dataset, batch_size=2, pin_memory=True, pin_memory_device="cuda"
         )
         for sample in loader:
-            self.assertTrue(sample["a_tensor"].is_pinned(device="cuda"))
-            self.assertTrue(sample["another_dict"]["a_number"].is_pinned(device="cuda"))
+            self.assertTrue(sample["a_tensor"].is_pinned())
+            self.assertTrue(sample["another_dict"]["a_number"].is_pinned())

     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     def test_pin_memory_with_only_device(self):
         loader = DataLoader(self.dataset, batch_size=2, pin_memory_device="cuda")
         for sample in loader:
-            self.assertFalse(sample["a_tensor"].is_pinned(device="cuda"))
-            self.assertFalse(
-                sample["another_dict"]["a_number"].is_pinned(device="cuda")
-            )
+            self.assertFalse(sample["a_tensor"].is_pinned())
+            self.assertFalse(sample["another_dict"]["a_number"].is_pinned())

 class DummyDataset(torch.utils.data.Dataset):