zeshengzong 216ba6e5f2 Fix MaskedTensor to device ignored mask (#151205)
Fixes #147140

## Changes

- Add a `to` implementation to `MaskedTensor` so that `mask` is moved to the target device together with the data
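The invariant this change restores is that a masked tensor's data and mask always live on the same device. The toy model below illustrates it without PyTorch. `ToyTensor`, `ToyMaskedTensor`, and `to_data_only` are hypothetical names invented for this sketch. The real `MaskedTensor.to` may be implemented differently, and it presumably also has to accept the other arguments `Tensor.to` takes, such as `dtype`.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyTensor:
    """Hypothetical stand-in for a tensor: tracks only values and a device name."""
    values: tuple
    device: str = "cpu"

    def to(self, device: str) -> "ToyTensor":
        # Return a copy placed on the target device (values unchanged).
        return replace(self, device=device)


@dataclass(frozen=True)
class ToyMaskedTensor:
    """Hypothetical masked wrapper holding a data tensor and a boolean mask tensor."""
    data: ToyTensor
    mask: ToyTensor

    def to_data_only(self, device: str) -> "ToyMaskedTensor":
        # Models the reported bug: the data moves, the mask stays behind.
        return ToyMaskedTensor(self.data.to(device), self.mask)

    def to(self, device: str) -> "ToyMaskedTensor":
        # Models the fix: data and mask move together, so they share a device.
        return ToyMaskedTensor(self.data.to(device), self.mask.to(device))


mt = ToyMaskedTensor(ToyTensor((1, 2, 3)), ToyTensor((True, False, True)))

buggy = mt.to_data_only("cuda")
print(buggy.data.device, buggy.mask.device)  # cuda cpu

fixed = mt.to("cuda")
print(fixed.data.device, fixed.mask.device)  # cuda cuda
```

Because both fields are moved through the same `to` call, no caller can end up with a mask stranded on the original device, which is the mismatch reported in #147140.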

## Test Result

```python
In [1]: import torch
   ...: from torch.masked import as_masked_tensor
   ...: data = torch.tensor([1,2,3])
   ...: mask = torch.tensor([True,False,True])
   ...: mt = as_masked_tensor(data, mask).to('cuda')
   ...: mt.get_data().device, mt.get_mask().device
/home/zong/code/pytorch/torch/masked/maskedtensor/core.py:247: UserWarning: The PyTorch API of MaskedTensors is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.masked module for further information about the project.
  return MaskedTensor(data, mask)
/home/zong/code/pytorch/torch/masked/maskedtensor/_ops_refs.py:354: UserWarning: The PyTorch API of MaskedTensors is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.masked module for further information about the project.
  return MaskedTensor(new_data, _maybe_get_mask(args[0]))
Out[1]: (device(type='cuda', index=0), device(type='cuda', index=0))

In [2]: mt.sum(dim=0)
/home/zong/code/pytorch/torch/masked/maskedtensor/core.py:247: UserWarning: The PyTorch API of MaskedTensors is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.masked module for further information about the project.
  return MaskedTensor(data, mask)
Out[2]: MaskedTensor(4, True)

```

```bash
pytest test/test_maskedtensor.py -vv
```

![image](https://github.com/user-attachments/assets/640b809c-b4f0-4aca-a09e-04049017a745)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151205
Approved by: https://github.com/ezyang
2025-07-21 21:44:49 +00:00