Fixes #147140

## Changes

- Add a `to` implementation in `MaskedTensor` to support moving the `mask` to the target device

## Test Result

```python
In [1]: import torch
   ...: from torch.masked import as_masked_tensor
   ...: data = torch.tensor([1,2,3])
   ...: mask = torch.tensor([True,False,True])
   ...: mt = as_masked_tensor(data, mask).to('cuda')
   ...: mt.get_data().device, mt.get_mask().device
/home/zong/code/pytorch/torch/masked/maskedtensor/core.py:247: UserWarning: The PyTorch API of MaskedTensors is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.masked module for further information about the project.
  return MaskedTensor(data, mask)
/home/zong/code/pytorch/torch/masked/maskedtensor/_ops_refs.py:354: UserWarning: The PyTorch API of MaskedTensors is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.masked module for further information about the project.
  return MaskedTensor(new_data, _maybe_get_mask(args[0]))
Out[1]: (device(type='cuda', index=0), device(type='cuda', index=0))

In [2]: mt.sum(dim=0)
/home/zong/code/pytorch/torch/masked/maskedtensor/core.py:247: UserWarning: The PyTorch API of MaskedTensors is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.masked module for further information about the project.
  return MaskedTensor(data, mask)
Out[2]: MaskedTensor(4, True)
```

```bash
pytest test/test_maskedtensor.py -vv
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151205
Approved by: https://github.com/ezyang