From 216ba6e5f235bbfa0b025303ad4aa5ee473c5a8b Mon Sep 17 00:00:00 2001 From: zeshengzong Date: Mon, 21 Jul 2025 21:44:44 +0000 Subject: [PATCH] Fix `MaskedTensor` to device ignored mask (#151205) Fixes #147140 ## Changes - Add `to` implementation in `MaskedTensor` to support move `mask` to target device ## Test Result ```python In [1]: import torch ...: from torch.masked import as_masked_tensor ...: data = torch.tensor([1,2,3]) ...: mask = torch.tensor([True,False,True]) ...: mt = as_masked_tensor(data, mask).to('cuda') ...: mt.get_data().device, mt.get_mask().device /home/zong/code/pytorch/torch/masked/maskedtensor/core.py:247: UserWarning: The PyTorch API of MaskedTensors is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.masked module for further information about the project. return MaskedTensor(data, mask) /home/zong/code/pytorch/torch/masked/maskedtensor/_ops_refs.py:354: UserWarning: The PyTorch API of MaskedTensors is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.masked module for further information about the project. return MaskedTensor(new_data, _maybe_get_mask(args[0])) Out[1]: (device(type='cuda', index=0), device(type='cuda', index=0)) In [2]: mt.sum(dim=0) /home/zong/code/pytorch/torch/masked/maskedtensor/core.py:247: UserWarning: The PyTorch API of MaskedTensors is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.masked module for further information about the project. return MaskedTensor(data, mask) Out[2]: MaskedTensor(4, True) ``` ```bash pytest test/test_maskedtensor.py -vv ``` ![image](https://github.com/user-attachments/assets/640b809c-b4f0-4aca-a09e-04049017a745) Pull Request resolved: https://github.com/pytorch/pytorch/pull/151205 Approved by: https://github.com/ezyang --- test/test_maskedtensor.py | 26 ++++++++++++++++++++++++++ torch/masked/maskedtensor/_ops_refs.py | 5 ++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/test/test_maskedtensor.py b/test/test_maskedtensor.py index db1ffbc38c1f..03c05c7ea6da 100644 --- a/test/test_maskedtensor.py +++ b/test/test_maskedtensor.py @@ -236,6 +236,32 @@ class TestBasics(TestCase): _compare_mt_t(sparse_mt, data) _compare_mt_t(mt.grad, data.grad) + def test_to_device(self, device): + for sample in _generate_sample_data(device=device): + data = sample.input + mask = sample.kwargs["mask"] + mt = masked_tensor(data, mask, requires_grad=True) + + new_device = torch.device("cuda") if device != "cuda" and torch.cuda.is_available() else torch.device("cpu") + mt_device = mt.to(new_device) + + self.assertEqual(mt_device.device.type, new_device.type) + self.assertEqual(mt_device.get_mask().device.type, new_device.type) + self.assertEqual(mt_device.get_data().device.type, new_device.type) + + def test_to_dtype(self, device): + for sample in _generate_sample_data(device=device): + data = sample.input + mask = sample.kwargs["mask"] + mt = masked_tensor(data, mask, requires_grad=True) + + new_dtype = torch.float64 if data.dtype == torch.float32 else torch.float32 + mt_dtype = mt.to(new_dtype) + + self.assertEqual(mt_dtype.dtype, new_dtype) + self.assertEqual(mt_dtype.get_mask().dtype, torch.bool) + self.assertEqual(mt_dtype.get_data().dtype, new_dtype) + def test_to_dense(self, device): samples = _generate_sample_data( device=device, diff --git a/torch/masked/maskedtensor/_ops_refs.py b/torch/masked/maskedtensor/_ops_refs.py index 719df7eac464..8135f149a1bf 100644 --- a/torch/masked/maskedtensor/_ops_refs.py +++ b/torch/masked/maskedtensor/_ops_refs.py @@ -351,7 +351,10 @@ def _apply_fn_on_data(func, *args, **kwargs): @register_dispatch_func([torch.ops.aten._to_copy]) def _to_copy(func, *args, **kwargs): new_data = func(_get_data(args[0]), *args[1:], **kwargs) - return MaskedTensor(new_data, _maybe_get_mask(args[0])) + cloned_kwargs = kwargs.copy() + cloned_kwargs["dtype"] = torch.bool + new_mask = func(_maybe_get_mask(args[0]), *args[1:], **cloned_kwargs) + return MaskedTensor(new_data, new_mask) @register_dispatch_func([torch.ops.aten._softmax])