mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix MaskedTensor
to device ignored mask (#151205)
Fixes #147140 ## Changes - Add `to` implementation in `MaskedTensor` to support move `mask` to target device ## Test Result ```python In [1]: import torch ...: from torch.masked import as_masked_tensor ...: data = torch.tensor([1,2,3]) ...: mask = torch.tensor([True,False,True]) ...: mt = as_masked_tensor(data, mask).to('cuda') ...: mt.get_data().device, mt.get_mask().device /home/zong/code/pytorch/torch/masked/maskedtensor/core.py:247: UserWarning: The PyTorch API of MaskedTensors is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.masked module for further information about the project. return MaskedTensor(data, mask) /home/zong/code/pytorch/torch/masked/maskedtensor/_ops_refs.py:354: UserWarning: The PyTorch API of MaskedTensors is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.masked module for further information about the project. return MaskedTensor(new_data, _maybe_get_mask(args[0])) Out[1]: (device(type='cuda', index=0), device(type='cuda', index=0)) In [2]: mt.sum(dim=0) /home/zong/code/pytorch/torch/masked/maskedtensor/core.py:247: UserWarning: The PyTorch API of MaskedTensors is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.masked module for further information about the project. return MaskedTensor(data, mask) Out[2]: MaskedTensor(4, True) ``` ```bash pytest test/test_maskedtensor.py -vv ```  Pull Request resolved: https://github.com/pytorch/pytorch/pull/151205 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
c774180e59
commit
216ba6e5f2
@ -236,6 +236,32 @@ class TestBasics(TestCase):
|
|||||||
_compare_mt_t(sparse_mt, data)
|
_compare_mt_t(sparse_mt, data)
|
||||||
_compare_mt_t(mt.grad, data.grad)
|
_compare_mt_t(mt.grad, data.grad)
|
||||||
|
|
||||||
|
def test_to_device(self, device):
|
||||||
|
for sample in _generate_sample_data(device=device):
|
||||||
|
data = sample.input
|
||||||
|
mask = sample.kwargs["mask"]
|
||||||
|
mt = masked_tensor(data, mask, requires_grad=True)
|
||||||
|
|
||||||
|
new_device = torch.device("cuda") if device != "cuda" and torch.cuda.is_available() else torch.device("cpu")
|
||||||
|
mt_device = mt.to(new_device)
|
||||||
|
|
||||||
|
self.assertEqual(mt_device.device.type, new_device.type)
|
||||||
|
self.assertEqual(mt_device.get_mask().device.type, new_device.type)
|
||||||
|
self.assertEqual(mt_device.get_data().device.type, new_device.type)
|
||||||
|
|
||||||
|
def test_to_dtype(self, device):
|
||||||
|
for sample in _generate_sample_data(device=device):
|
||||||
|
data = sample.input
|
||||||
|
mask = sample.kwargs["mask"]
|
||||||
|
mt = masked_tensor(data, mask, requires_grad=True)
|
||||||
|
|
||||||
|
new_dtype = torch.float64 if data.dtype == torch.float32 else torch.float32
|
||||||
|
mt_dtype = mt.to(new_dtype)
|
||||||
|
|
||||||
|
self.assertEqual(mt_dtype.dtype, new_dtype)
|
||||||
|
self.assertEqual(mt_dtype.get_mask().dtype, torch.bool)
|
||||||
|
self.assertEqual(mt_dtype.get_data().dtype, new_dtype)
|
||||||
|
|
||||||
def test_to_dense(self, device):
|
def test_to_dense(self, device):
|
||||||
samples = _generate_sample_data(
|
samples = _generate_sample_data(
|
||||||
device=device,
|
device=device,
|
||||||
|
@ -351,7 +351,10 @@ def _apply_fn_on_data(func, *args, **kwargs):
|
|||||||
@register_dispatch_func([torch.ops.aten._to_copy])
|
@register_dispatch_func([torch.ops.aten._to_copy])
|
||||||
def _to_copy(func, *args, **kwargs):
|
def _to_copy(func, *args, **kwargs):
|
||||||
new_data = func(_get_data(args[0]), *args[1:], **kwargs)
|
new_data = func(_get_data(args[0]), *args[1:], **kwargs)
|
||||||
return MaskedTensor(new_data, _maybe_get_mask(args[0]))
|
cloned_kwargs = kwargs.copy()
|
||||||
|
cloned_kwargs["dtype"] = torch.bool
|
||||||
|
new_mask = func(_maybe_get_mask(args[0]), *args[1:], **cloned_kwargs)
|
||||||
|
return MaskedTensor(new_data, new_mask)
|
||||||
|
|
||||||
|
|
||||||
@register_dispatch_func([torch.ops.aten._softmax])
|
@register_dispatch_func([torch.ops.aten._softmax])
|
||||||
|
Reference in New Issue
Block a user