Fix `MaskedTensor.to(device)` ignoring the mask (#151205)

Fixes #147140

## Changes

- Extend the `aten._to_copy` dispatch used by `MaskedTensor.to` so that the `mask` is moved to the target device along with the data (while keeping its dtype as `torch.bool`), as sketched below
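
The core idea, as a minimal standalone sketch (the helper name `_to_copy_masked` and the free-function form are illustrative, not the actual dispatch registration):

```python
import torch

# Sketch of the fix: when a MaskedTensor is converted with `to`, apply the
# same conversion to the mask, but pin its dtype to torch.bool so a dtype
# conversion of the data never turns the mask into a floating-point tensor.
def _to_copy_masked(data: torch.Tensor, mask: torch.Tensor, **kwargs):
    new_data = data.to(**kwargs)
    mask_kwargs = dict(kwargs)
    mask_kwargs["dtype"] = torch.bool  # the mask always stays boolean
    new_mask = mask.to(**mask_kwargs)
    return new_data, new_mask
```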

## Test Result

```python
In [1]: import torch
   ...: from torch.masked import as_masked_tensor
   ...: data = torch.tensor([1,2,3])
   ...: mask = torch.tensor([True,False,True])
   ...: mt = as_masked_tensor(data, mask).to('cuda')
   ...: mt.get_data().device, mt.get_mask().device
/home/zong/code/pytorch/torch/masked/maskedtensor/core.py:247: UserWarning: The PyTorch API of MaskedTensors is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.masked module for further information about the project.
  return MaskedTensor(data, mask)
/home/zong/code/pytorch/torch/masked/maskedtensor/_ops_refs.py:354: UserWarning: The PyTorch API of MaskedTensors is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.masked module for further information about the project.
  return MaskedTensor(new_data, _maybe_get_mask(args[0]))
Out[1]: (device(type='cuda', index=0), device(type='cuda', index=0))

In [2]: mt.sum(dim=0)
/home/zong/code/pytorch/torch/masked/maskedtensor/core.py:247: UserWarning: The PyTorch API of MaskedTensors is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.masked module for further information about the project.
  return MaskedTensor(data, mask)
Out[2]: MaskedTensor(4, True)

```

```bash
pytest test/test_maskedtensor.py -vv
```
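
To run only the newly added tests, a narrowed invocation along these lines should work (the `-k` expression is illustrative; device-parametrized test names still match by substring):

```bash
pytest test/test_maskedtensor.py -k "test_to_device or test_to_dtype" -vv
```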

![image](https://github.com/user-attachments/assets/640b809c-b4f0-4aca-a09e-04049017a745)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151205
Approved by: https://github.com/ezyang
Author: zeshengzong
Date: 2025-07-21 21:44:44 +00:00
Committed by: PyTorch MergeBot
Commit: 216ba6e5f2 (parent: c774180e59)
2 changed files with 30 additions and 1 deletion

`test/test_maskedtensor.py`

```diff
@@ -236,6 +236,32 @@ class TestBasics(TestCase):
             _compare_mt_t(sparse_mt, data)
             _compare_mt_t(mt.grad, data.grad)
 
+    def test_to_device(self, device):
+        for sample in _generate_sample_data(device=device):
+            data = sample.input
+            mask = sample.kwargs["mask"]
+            mt = masked_tensor(data, mask, requires_grad=True)
+
+            new_device = torch.device("cuda") if device != "cuda" and torch.cuda.is_available() else torch.device("cpu")
+            mt_device = mt.to(new_device)
+
+            self.assertEqual(mt_device.device.type, new_device.type)
+            self.assertEqual(mt_device.get_mask().device.type, new_device.type)
+            self.assertEqual(mt_device.get_data().device.type, new_device.type)
+
+    def test_to_dtype(self, device):
+        for sample in _generate_sample_data(device=device):
+            data = sample.input
+            mask = sample.kwargs["mask"]
+            mt = masked_tensor(data, mask, requires_grad=True)
+
+            new_dtype = torch.float64 if data.dtype == torch.float32 else torch.float32
+            mt_dtype = mt.to(new_dtype)
+
+            self.assertEqual(mt_dtype.dtype, new_dtype)
+            self.assertEqual(mt_dtype.get_mask().dtype, torch.bool)
+            self.assertEqual(mt_dtype.get_data().dtype, new_dtype)
+
     def test_to_dense(self, device):
         samples = _generate_sample_data(
             device=device,
```

`torch/masked/maskedtensor/_ops_refs.py`

```diff
@@ -351,7 +351,10 @@ def _apply_fn_on_data(func, *args, **kwargs):
 @register_dispatch_func([torch.ops.aten._to_copy])
 def _to_copy(func, *args, **kwargs):
     new_data = func(_get_data(args[0]), *args[1:], **kwargs)
-    return MaskedTensor(new_data, _maybe_get_mask(args[0]))
+    cloned_kwargs = kwargs.copy()
+    cloned_kwargs["dtype"] = torch.bool
+    new_mask = func(_maybe_get_mask(args[0]), *args[1:], **cloned_kwargs)
+    return MaskedTensor(new_data, new_mask)
 
 
 @register_dispatch_func([torch.ops.aten._softmax])
```
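
The `dtype` override in `cloned_kwargs` is what keeps the mask boolean when `to` performs a dtype conversion rather than a device move: only the data's dtype changes. A quick interactive check of the behavior the new `test_to_dtype` asserts (prototype-API warnings omitted):

```python
import torch
from torch.masked import masked_tensor

data = torch.tensor([1.0, 2.0, 3.0])
mask = torch.tensor([True, False, True])
mt = masked_tensor(data, mask)

# Converting the MaskedTensor's dtype changes only the data; the mask is
# copied through the same _to_copy path but pinned to torch.bool.
mt64 = mt.to(torch.float64)
print(mt64.get_data().dtype)  # torch.float64
print(mt64.get_mask().dtype)  # torch.bool
```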