[BE] enable UFMT for torch/masked/ (#127715)

Part of #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127715
Approved by: https://github.com/cpuhrsch
Author: Xuehai Pan
Date: 2024-06-03 22:01:46 +00:00
Committed by: PyTorch MergeBot
Parent: 406532f864
Commit: 01fc22056a
10 changed files with 206 additions and 156 deletions

File: .lintrunner.toml

@@ -1674,18 +1674,6 @@ exclude_patterns = [
'torch/hub.py',
'torch/library.py',
'torch/linalg/__init__.py',
# UFMT causes import cycle on masked
'torch/masked/__init__.py',
'torch/masked/_docs.py',
'torch/masked/_ops.py',
'torch/masked/maskedtensor/__init__.py',
'torch/masked/maskedtensor/_ops_refs.py',
'torch/masked/maskedtensor/binary.py',
'torch/masked/maskedtensor/core.py',
'torch/masked/maskedtensor/creation.py',
'torch/masked/maskedtensor/passthrough.py',
'torch/masked/maskedtensor/reductions.py',
'torch/masked/maskedtensor/unary.py',
'torch/monitor/__init__.py',
'torch/nested/__init__.py',
'torch/nn/__init__.py',
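
Removing these paths from exclude_patterns opts torch/masked/ into the UFMT lint rule, which runs usort (import sorting) followed by Black (layout normalization); the "import cycle" escape hatch is no longer needed, presumably because the relative imports in the hunks below are rewritten as absolute ones, which usort can reorder safely. The formatting should be reproducible with something like "lintrunner -a --take UFMT" on these files, assuming a standard PyTorch lint setup. For readers unfamiliar with the Black half, a minimal standalone sketch (not part of this commit; assumes the black package is installed):

import black  # Black is the formatter UFMT applies after import sorting

src = "def f(  a,b ):\n    return {  'a':a,'b':b }\n"

# format_str applies Black's style rules to a source string in memory
print(black.format_str(src, mode=black.Mode()))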

File: torch/masked/__init__.py

@@ -1,33 +1,34 @@
from .maskedtensor.core import is_masked_tensor, MaskedTensor
from .maskedtensor.creation import as_masked_tensor, masked_tensor
from ._ops import (
from torch.masked._ops import (
_canonical_dim,
_combine_input_and_mask,
_generate_docstring,
_reduction_identity,
_where,
_input_mask,
_output_mask,
_combine_input_and_mask,
sum,
prod,
cumsum,
cumprod,
_reduction_identity,
_where,
amax,
amin,
argmax,
argmin,
cumprod,
cumsum,
log_softmax,
logaddexp,
logsumexp,
mean,
median,
logsumexp,
logaddexp,
norm,
var,
std,
softmax,
log_softmax,
softmin,
normalize,
prod,
softmax,
softmin,
std,
sum,
var,
)
from torch.masked.maskedtensor.core import is_masked_tensor, MaskedTensor
from torch.masked.maskedtensor.creation import as_masked_tensor, masked_tensor
__all__ = [
"as_masked_tensor",

File: torch/masked/_ops.py

@@ -1,15 +1,12 @@
import warnings
# A workaround to support both TorchScript and MyPy:
from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union
import torch
from torch import Tensor
from torch.masked import as_masked_tensor, is_masked_tensor, MaskedTensor
from . import _docs
from torch import sym_float, Tensor
from torch._prims_common import corresponding_real_dtype
from torch import sym_float
from torch.masked import _docs
from torch.masked.maskedtensor.core import is_masked_tensor, MaskedTensor
from torch.masked.maskedtensor.creation import as_masked_tensor
if TYPE_CHECKING:
from torch.types import _dtype as DType
@@ -469,7 +466,7 @@ def _canonical_dim(dim: DimOrDims, ndim: int) -> Tuple[int, ...]:
raise RuntimeError(f"dim={d} appears multiple times in the list of dims")
if d >= ndim or d < -ndim:
raise IndexError(
f"Dimension out of range (expected to be in range of [{-ndim}, {ndim-1}], but got {d})"
f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {d})"
)
dims.append(d % ndim)
return tuple(sorted(dims))
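
For illustration, a hedged sketch of how the canonicalization above behaves (a private helper, shown only because this hunk touches it):

from torch.masked._ops import _canonical_dim

print(_canonical_dim((-1, 0), ndim=3))  # negative dims wrap via d % ndim: (0, 2)
# Duplicates are rejected with the RuntimeError raised above, e.g.
# _canonical_dim((1, 1), ndim=3) -> "dim=1 appears multiple times ..."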
@@ -1420,7 +1417,6 @@ def median(
dtype: Optional[DType] = None,
mask: Optional[Tensor] = None,
) -> Tensor:
"""\
{reduction_signature}
{reduction_descr}
@@ -1487,46 +1483,45 @@ def logaddexp(
) -> Tensor:
"""logaddexp(input, other, *, dtype=None, input_mask=None, other_mask=None) -> Tensor
Returns logaddexp of all the elements in the :attr:`input` and the :attr:`other`
tensor. The :attr:`input` elements are masked out according to the boolean tensor
:attr:`input_mask` and the attr:`other` elements are masked out according to the boolean tensor
:attr:`other_mask`.
Returns logaddexp of all the elements in the :attr:`input` and the :attr:`other`
tensor. The :attr:`input` elements are masked out according to the boolean tensor
:attr:`input_mask` and the attr:`other` elements are masked out according to the boolean tensor
:attr:`other_mask`.
The shapes of a mask tensor and the tensor to be masked
don't need to match, but they must be :ref:`broadcastable
<broadcasting-semantics>` and the dimensionality of the mask
tensor must not be greater than of the tensor to be masked.
The shapes of a mask tensor and the tensor to be masked
don't need to match, but they must be :ref:`broadcastable
<broadcasting-semantics>` and the dimensionality of the mask
tensor must not be greater than of the tensor to be masked.
Args:
input (Tensor): the input tensor
other (Tensor): the second input tensor
Args:
input (Tensor): the input tensor
other (Tensor): the second input tensor
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type
of returned tensor. If specified, the output tensor is
casted to :attr:`dtype` after the operation is
performed. Default: None.
input_mask (:class:`torch.Tensor`, optional): the boolean tensor
containing the binary mask of validity of :attr:`input` tensor elements.
Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
other_mask (:class:`torch.Tensor`, optional): the boolean tensor
containing the binary mask of validity of :attr:`other` tensor elements.
Default: None that is equivalent to ``torch.ones(other.shape, dtype=torch.bool)``.
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type
of returned tensor. If specified, the output tensor is
casted to :attr:`dtype` after the operation is
performed. Default: None.
input_mask (:class:`torch.Tensor`, optional): the boolean tensor
containing the binary mask of validity of :attr:`input` tensor elements.
Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
other_mask (:class:`torch.Tensor`, optional): the boolean tensor
containing the binary mask of validity of :attr:`other` tensor elements.
Default: None that is equivalent to ``torch.ones(other.shape, dtype=torch.bool)``.
Example::
Example::
>>> input = torch.tensor([-100.0, -200, -300])
>>> input
tensor([-100., -200., -300.])
>>> other = torch.tensor([-1.0, -2, -3])
>>> other
tensor([-1., -2., -3.])
>>> mask = torch.tensor([True, False, True])
>>> mask
tensor([ True, False, True])
>>> torch.masked._ops.logaddexp(input, other, input_mask=mask, other_mask=mask)
tensor([-1., -inf, -3.])
"""
>>> input = torch.tensor([-100.0, -200, -300])
>>> input
tensor([-100., -200., -300.])
>>> other = torch.tensor([-1.0, -2, -3])
>>> other
tensor([-1., -2., -3.])
>>> mask = torch.tensor([True, False, True])
>>> mask
tensor([ True, False, True])
>>> torch.masked._ops.logaddexp(input, other, input_mask=mask, other_mask=mask)
tensor([-1., -inf, -3.])"""
if dtype is None:
dtype = input.dtype
if input.layout == torch.strided and other.layout == torch.strided:
@@ -1586,7 +1581,9 @@ def _std_var(
mask: Optional[Tensor],
take_sqrt: Optional[bool],
) -> Tensor:
assert (unbiased is None or correction_opt is None), "Only one of unbiased and correction may be given"
assert (
unbiased is None or correction_opt is None
), "Only one of unbiased and correction may be given"
correction = 1.0
if unbiased is not None:
correction = 1.0 if unbiased else 0.0
@@ -1632,8 +1629,11 @@ def _std_var(
if not keepdim:
count = count.reshape(total.shape)
if correction != 0:
real_dtype = (corresponding_real_dtype(compute_dtype)
if compute_dtype.is_complex else compute_dtype)
real_dtype = (
corresponding_real_dtype(compute_dtype)
if compute_dtype.is_complex
else compute_dtype
)
count = count.to(real_dtype)
count = torch.subtract(count, correction)
count = torch.maximum(count, count.new_zeros([]))
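
The correction handling above is Bessel's correction generalized: the masked element count minus the correction, clamped at zero, becomes the denominator. A hedged numerical check, assuming the public torch.masked.var entry point (exported by this package) accepts these keywords:

import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
mask = torch.tensor([True, True, True, False])

# masked variance over the three valid elements; unbiased=True -> correction = 1
mv = torch.masked.var(x, unbiased=True, mask=mask)

# manual computation over the same elements: sum of squared deviations / (3 - 1)
manual = ((x[:3] - x[:3].mean()) ** 2).sum() / (3 - 1)
print(mv, manual)  # both 1.0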

File: torch/masked/maskedtensor/_ops_refs.py

@@ -1,43 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from functools import partial
from typing import Callable, Any, Dict, TYPE_CHECKING
from typing import Any, Callable, Dict, TYPE_CHECKING
import torch
if TYPE_CHECKING:
import torch._ops
from .binary import (
_apply_native_binary,
NATIVE_BINARY_FNS,
NATIVE_INPLACE_BINARY_FNS,
)
from .core import is_masked_tensor, MaskedTensor, _get_data, _masks_match, _maybe_get_mask
from .passthrough import (
_apply_pass_through_fn,
PASSTHROUGH_FNS
from .binary import _apply_native_binary, NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS
from .core import (
_get_data,
_masks_match,
_maybe_get_mask,
is_masked_tensor,
MaskedTensor,
)
from .passthrough import _apply_pass_through_fn, PASSTHROUGH_FNS
from .reductions import (
_apply_reduction,
NATIVE_REDUCE_FNS,
TORCH_REDUCE_FNS,
TENSOR_REDUCE_FNS,
TORCH_REDUCE_FNS,
)
from .unary import (
_apply_native_unary,
NATIVE_UNARY_FNS,
NATIVE_INPLACE_UNARY_FNS,
)
from .unary import _apply_native_unary, NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS
__all__ = [] # type: ignore[var-annotated]
def _check_args_kwargs_length(args, kwargs, error_prefix, len_args=None, len_kwargs=None):
def _check_args_kwargs_length(
args, kwargs, error_prefix, len_args=None, len_kwargs=None
):
if len_args is not None and len_args != len(args):
raise ValueError(f"{error_prefix}: len(args) must be {len_args} but got {len(args)}")
raise ValueError(
f"{error_prefix}: len(args) must be {len_args} but got {len(args)}"
)
if len_kwargs is not None and len_kwargs != len(kwargs):
raise ValueError(f"{error_prefix}: len(kwargs) must be {len_kwargs} but got {len(kwargs)}")
raise ValueError(
f"{error_prefix}: len(kwargs) must be {len_kwargs} but got {len(kwargs)}"
)
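
A hedged usage sketch of the arity guard defined above (same module scope assumed, values chosen for illustration):

try:
    _check_args_kwargs_length((1, 2), {}, "example", len_args=3, len_kwargs=0)
except ValueError as e:
    print(e)  # example: len(args) must be 3 but got 2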
class _MaskedContiguous(torch.autograd.Function):
@@ -116,7 +118,9 @@ class _MaskedToSparseCsr(torch.autograd.Function):
raise ValueError("MaskedToSparseCsr forward: input must be a MaskedTensor.")
if input._masked_data.ndim != 2:
raise ValueError(f"Only 2D tensors can be converted to the SparseCsr layout but got shape: {input._masked_data.size()}")
raise ValueError(
f"Only 2D tensors can be converted to the SparseCsr layout but got shape: {input._masked_data.size()}"
)
if input.layout == torch.sparse_csr:
return input
@@ -157,7 +161,11 @@ class _MaskedWhere(torch.autograd.Function):
_MASKEDTENSOR_FUNCTION_TABLE = {}
_function_fn_apply_map = {
(tuple(NATIVE_REDUCE_FNS), tuple(TORCH_REDUCE_FNS), tuple(TENSOR_REDUCE_FNS)): _apply_reduction,
(
tuple(NATIVE_REDUCE_FNS),
tuple(TORCH_REDUCE_FNS),
tuple(TENSOR_REDUCE_FNS),
): _apply_reduction,
}
for fn_map_list, apply_fn in _function_fn_apply_map.items():
@@ -177,9 +185,11 @@ def register_function_func(ops):
def foo(func, *args, **kwargs):
<implementation>
"""
def wrapper(func):
for op in ops:
_MASKEDTENSOR_FUNCTION_TABLE[op] = partial(func, op)
return wrapper
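
A self-contained sketch of the registration pattern above, with a hypothetical table and op names (not the real ones from _ops_refs.py). Note that the inner function returns nothing, matching the code above; decorated handlers are only ever reached through the table:

from functools import partial

_TABLE = {}  # hypothetical stand-in for _MASKEDTENSOR_FUNCTION_TABLE

def register(ops):
    def wrapper(func):
        for op in ops:
            _TABLE[op] = partial(func, op)  # bind the op as the first argument
    return wrapper

@register(["sum", "prod"])
def _handle(op, *args):
    return f"{op}{args}"

print(_TABLE["prod"](2, 3))  # prod(2, 3)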
@@ -190,7 +200,9 @@ def _general_function_reductions(func, *args, **kwargs):
@register_function_func([torch.Tensor.where, torch.where])
def _function_where(func, *args, **kwargs):
_check_args_kwargs_length(args, kwargs, "__torch_function__, torch.where", len_args=3, len_kwargs=0)
_check_args_kwargs_length(
args, kwargs, "__torch_function__, torch.where", len_args=3, len_kwargs=0
)
return _MaskedWhere.apply(*args)
@@ -216,6 +228,7 @@ def _function_to_sparse_csr(func, *args, **kwargs):
_MASKEDTENSOR_DISPATCH_TABLE: Dict["torch._ops.OpOverload", Callable[..., Any]] = {}
def register_dispatch_func(aten_ops):
"""
Used for registering a new __torch_dispatch__ function to MaskedTensor
@@ -227,9 +240,11 @@ def register_dispatch_func(aten_ops):
def foo(func, *args, **kwargs):
<implementation>
"""
def wrapper(func):
for aten_op in aten_ops:
_MASKEDTENSOR_DISPATCH_TABLE[aten_op] = partial(func, aten_op)
return wrapper
@@ -272,9 +287,7 @@ def layout(func, *args, **kwargs):
def is_contiguous(func, *args, **kwargs):
data = _get_data(args[0])
if data.is_sparse:
raise ValueError(
"MaskedTensors with sparse data do not have is_contiguous"
)
raise ValueError("MaskedTensors with sparse data do not have is_contiguous")
return func(data, *args[1:], **kwargs)
@@ -301,9 +314,7 @@ def is_non_overlapping_and_dense(func, *args, **kwargs):
@register_dispatch_func([torch.ops.aten.contiguous])
def contiguous(func, *args, **kwargs):
if _get_data(args[0]).is_sparse:
raise ValueError(
"MaskedTensors with sparse data do not have contiguous"
)
raise ValueError("MaskedTensors with sparse data do not have contiguous")
return _MaskedContiguous.apply(args[0])
@@ -313,9 +324,13 @@ def new_empty_strided(func, *args, **kwargs):
data = _get_data(args[0])
mask = _maybe_get_mask(args[0])
if tuple(args[1]) != tuple(data.size()):
raise ValueError(f"__torch_dispatch__, {func}: args[1] expected to be the same as data.size()")
raise ValueError(
f"__torch_dispatch__, {func}: args[1] expected to be the same as data.size()"
)
if tuple(args[2]) != tuple(data.stride()):
raise ValueError(f"__torch_dispatch__, {func}: args[2] expected to be the same as data.stride()")
raise ValueError(
f"__torch_dispatch__, {func}: args[2] expected to be the same as data.stride()"
)
return MaskedTensor(func(data, args[1], args[2], **kwargs), mask)
@@ -339,7 +354,9 @@ def _to_copy(func, *args, **kwargs):
@register_dispatch_func([torch.ops.aten._softmax])
def _softmax(func, *args, **kwargs):
_check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0)
_check_args_kwargs_length(
args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0
)
data = _get_data(args[0])
mask = _maybe_get_mask(args[0])
result_data = torch.ops.aten._masked_softmax(data, ~mask, args[1], 2)
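
A hedged end-to-end example of this dispatch path (prototype API, so expect the UserWarning from core.py):

import torch
from torch.masked import masked_tensor

x = masked_tensor(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([True, True, False]))
print(torch.softmax(x, dim=0))  # weights computed over the two unmasked slots only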
@@ -359,7 +376,9 @@ def _softmax_backward_data(func, *args, **kwargs):
grad, output, dim, input_dtype = args
if is_masked_tensor(grad) and is_masked_tensor(output):
if not _masks_match(grad, output):
raise ValueError("__torch_dispatch__, {func}: expected the masks of grad and output to match")
raise ValueError(
"__torch_dispatch__, {func}: expected the masks of grad and output to match"
)
grad_data = _get_data(grad)
new_grad_data = torch.ops.aten._masked_softmax_backward(
grad_data,
@@ -370,7 +389,9 @@ def _softmax_backward_data(func, *args, **kwargs):
res = MaskedTensor(new_grad_data, _maybe_get_mask(grad))
return res
else:
raise ValueError(f"__torch_dispatch__, {func}: grad and output must both be MaskedTensors")
raise ValueError(
f"__torch_dispatch__, {func}: grad and output must both be MaskedTensors"
)
@register_dispatch_func([torch.ops.aten.copy_])
@@ -384,7 +405,9 @@ def copy_(func, *args, **kwargs):
@register_dispatch_func([torch.ops.aten.where])
def where(func, *args, **kwargs):
_check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0)
_check_args_kwargs_length(
args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0
)
if not torch.is_tensor(args[0]):
raise ValueError("__torch_dispatch__, {func}: expected args[0] to be a tensor")
mx = args[1]
@@ -400,7 +423,9 @@ def where(func, *args, **kwargs):
@register_dispatch_func([torch.ops.aten._to_sparse])
def _to_sparse(func, *args, **kwargs):
_check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0)
_check_args_kwargs_length(
args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
)
if not torch.is_tensor(args[0]):
raise TypeError("__torch_dispatch__, {func}: expected args[0] to be a tensor")
mt = args[0]
@@ -415,7 +440,9 @@ def _to_sparse(func, *args, **kwargs):
@register_dispatch_func([torch.ops.aten._to_sparse_csr])
def _to_sparse_csr(func, *args, **kwargs):
_check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0)
_check_args_kwargs_length(
args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
)
if not torch.is_tensor(args[0]):
raise ValueError("__torch_dispatch__, {func}: expected args[0] to be a tensor")
mt = args[0]
@@ -430,7 +457,9 @@ def _to_sparse_csr(func, *args, **kwargs):
@register_dispatch_func([torch.ops.aten._to_dense])
def _to_dense(func, *args, **kwargs):
_check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0)
_check_args_kwargs_length(
args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
)
if not torch.is_tensor(args[0]):
raise ValueError("__torch_dispatch__, {func}: expected args[0] to be a tensor")
mt = args[0]
@@ -444,14 +473,18 @@ def _to_dense(func, *args, **kwargs):
@register_dispatch_func([torch.ops.aten._indices])
def _indices(func, *args, **kwargs):
# Assumes data is sparse
_check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0)
_check_args_kwargs_length(
args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
)
data = _get_data(args[0]).indices()
return MaskedTensor(data, torch.ones_like(data).bool())
@register_dispatch_func([torch.ops.aten._values])
def _values(func, *args, **kwargs):
_check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0)
_check_args_kwargs_length(
args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
)
data = _get_data(args[0]).values()
return MaskedTensor(data, torch.ones_like(data).bool())

File: torch/masked/maskedtensor/binary.py

@@ -2,7 +2,14 @@
import torch
from .core import _map_mt_args_kwargs, _masks_match, _tensors_match, _wrap_result, is_masked_tensor
from .core import (
_map_mt_args_kwargs,
_masks_match,
_tensors_match,
_wrap_result,
is_masked_tensor,
)
__all__ = [] # type: ignore[var-annotated]
@@ -79,25 +86,22 @@ def _binary_helper(fn, args, kwargs, inplace):
raise ValueError("len(kwargs) must equal 0")
for a in args[2:]:
if torch.is_tensor(a):
raise TypeError("MaskedTensor binary ops do not support Tensor arguments aside from the lhs and rhs")
raise TypeError(
"MaskedTensor binary ops do not support Tensor arguments aside from the lhs and rhs"
)
if not _masks_match(*args[:2]):
raise ValueError(
"Input masks must match. If you need support for this, please open an issue on Github."
)
data_args, data_kwargs = _map_mt_args_kwargs(
args, kwargs, lambda x: x.get_data()
)
mask_args, mask_kwargs = _map_mt_args_kwargs(
args, kwargs, lambda x: x.get_mask()
)
data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data())
mask_args, mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask())
args0_layout = data_args[0].layout
same_layout = (
(torch.is_tensor(data_args[1]) or is_masked_tensor(data_args[1])) and
(args0_layout == data_args[1].layout)
)
torch.is_tensor(data_args[1]) or is_masked_tensor(data_args[1])
) and (args0_layout == data_args[1].layout)
if args0_layout == torch.sparse_coo:
if same_layout:
@@ -106,7 +110,9 @@ def _binary_helper(fn, args, kwargs, inplace):
"sparse_coo indices must match. If you need support for this, please open an issue on Github."
)
if data_args[0].size() != data_args[1].size():
raise ValueError("input1 and input2 must have the same size for binary functions.")
raise ValueError(
"input1 and input2 must have the same size for binary functions."
)
data_args[1] = data_args[1].values()
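
A hedged example of the mask-matching rule enforced above:

import torch
from torch.masked import masked_tensor

mask = torch.tensor([True, False, True])
a = masked_tensor(torch.tensor([1.0, 2.0, 3.0]), mask)
b = masked_tensor(torch.tensor([10.0, 20.0, 30.0]), mask)
print(a + b)  # fine: identical masks, the masked-out slot stays masked

c = masked_tensor(torch.tensor([1.0, 1.0, 1.0]), ~mask)
try:
    a + c  # differing masks
except ValueError as e:
    print(e)  # Input masks must match. ...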

File: torch/masked/maskedtensor/core.py

@@ -13,7 +13,7 @@ __all__ = [
def is_masked_tensor(a):
r""" Returns True if the input is a MaskedTensor, else False
r"""Returns True if the input is a MaskedTensor, else False
Args:
a: any input
@@ -35,7 +35,9 @@ def _tensors_match(a, b, exact=True, rtol=1e-05, atol=1e-08):
if is_masked_tensor(a) or is_masked_tensor(b):
raise ValueError("Neither `a` nor `b` can be a MaskedTensor.")
if a.layout != b.layout:
raise ValueError(f"`a` and `b` must have the same layout. Got {a.layout} and {b.layout}")
raise ValueError(
f"`a` and `b` must have the same layout. Got {a.layout} and {b.layout}"
)
if a.dtype != b.dtype:
b = b.type(a.dtype)
@@ -108,9 +110,7 @@ def _masked_tensor_str(data, mask, formatter):
formatter.format(d.item()) if isinstance(d.item(), float) else str(d.item())
for d in data
]
max_len = max(
8 if x[1] else len(x[0]) for x in zip(formatted_elements, ~mask)
)
max_len = max(8 if x[1] else len(x[0]) for x in zip(formatted_elements, ~mask))
return (
"["
+ ", ".join(
@@ -153,13 +153,21 @@ class MaskedTensor(torch.Tensor):
kwargs["requires_grad"] = requires_grad
kwargs["dispatch_sizes_strides_policy"] = "strides"
kwargs["dispatch_layout"] = True
warnings.warn(("The PyTorch API of MaskedTensors is in prototype stage "
"and will change in the near future. Please open a Github issue "
"for features requests and see our documentation on the torch.masked "
"module for further information about the project."), UserWarning)
warnings.warn(
(
"The PyTorch API of MaskedTensors is in prototype stage "
"and will change in the near future. Please open a Github issue "
"for features requests and see our documentation on the torch.masked "
"module for further information about the project."
),
UserWarning,
)
if data.requires_grad:
warnings.warn("It is not recommended to create a MaskedTensor with a tensor that requires_grad. "
"To avoid this, you can use data.clone().detach()", UserWarning)
warnings.warn(
"It is not recommended to create a MaskedTensor with a tensor that requires_grad. "
"To avoid this, you can use data.clone().detach()",
UserWarning,
)
return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs) # type: ignore[attr-defined]
def _preprocess_data(self, data, mask):
@@ -184,17 +192,23 @@ class MaskedTensor(torch.Tensor):
data = self._masked_data
mask = self.get_mask()
if type(data) != type(mask):
raise TypeError(f"data and mask must have the same type. Got {type(data)} and {type(mask)}")
raise TypeError(
f"data and mask must have the same type. Got {type(data)} and {type(mask)}"
)
if data.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}:
raise TypeError(f"data layout of {data.layout} is not supported.")
if data.layout == torch.sparse_coo:
if not _tensors_match(data.indices(), mask.indices(), exact=True):
raise ValueError("data and mask are both sparse COO tensors but do not have the same indices.")
raise ValueError(
"data and mask are both sparse COO tensors but do not have the same indices."
)
elif data.layout == torch.sparse_csr:
if not _tensors_match(
data.crow_indices(), mask.crow_indices(), exact=True
) or not _tensors_match(data.col_indices(), mask.col_indices(), exact=True):
raise ValueError("data and mask are both sparse CSR tensors but do not share either crow or col indices.")
raise ValueError(
"data and mask are both sparse CSR tensors but do not share either crow or col indices."
)
if mask.dtype != torch.bool:
raise TypeError("mask must have dtype bool.")
if not (
@@ -219,7 +233,8 @@ class MaskedTensor(torch.Tensor):
@staticmethod
def _from_values(data, mask):
""" Differentiable constructor for MaskedTensor """
"""Differentiable constructor for MaskedTensor"""
class Constructor(torch.autograd.Function):
@staticmethod
def forward(ctx, data, mask):
@@ -265,6 +280,7 @@ class MaskedTensor(torch.Tensor):
kwargs = kwargs or {}
from ._ops_refs import _MASKEDTENSOR_FUNCTION_TABLE
if func in _MASKEDTENSOR_FUNCTION_TABLE:
return _MASKEDTENSOR_FUNCTION_TABLE[func](*args, **kwargs)
@@ -286,6 +302,7 @@ class MaskedTensor(torch.Tensor):
func = func.overloadpacket
from ._ops_refs import _MASKEDTENSOR_DISPATCH_TABLE
if func in _MASKEDTENSOR_DISPATCH_TABLE:
return _MASKEDTENSOR_DISPATCH_TABLE[func](*args, **kwargs)

File: torch/masked/maskedtensor/creation.py

@@ -2,20 +2,21 @@
from .core import MaskedTensor
__all__ = [
"as_masked_tensor",
"masked_tensor",
]
""""
These two factory functions are intended to mirror
torch.tensor - guaranteed to be a leaf node
torch.as_tensor - differentiable constructor that preserves the autograd history
"""
# These two factory functions are intended to mirror
# torch.tensor - guaranteed to be a leaf node
# torch.as_tensor - differentiable constructor that preserves the autograd history
def masked_tensor(data, mask, requires_grad=False):
return MaskedTensor(data, mask, requires_grad)
def as_masked_tensor(data, mask):
return MaskedTensor._from_values(data, mask)
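
A hedged illustration of the leaf vs. differentiable distinction the comment above describes:

import torch
from torch.masked import as_masked_tensor, masked_tensor

data = torch.arange(3.0, requires_grad=True)
mask = torch.tensor([True, False, True])

print(masked_tensor(data.detach(), mask).is_leaf)  # True, mirrors torch.tensor
print(as_masked_tensor(data, mask).is_leaf)  # False, mirrors torch.as_tensor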

File: torch/masked/maskedtensor/passthrough.py

@@ -10,6 +10,7 @@ import torch
from .core import _map_mt_args_kwargs, _wrap_result
__all__ = [] # type: ignore[var-annotated]

File: torch/masked/maskedtensor/reductions.py

@@ -7,6 +7,7 @@ import torch
from .core import is_masked_tensor
from .creation import as_masked_tensor, masked_tensor
__all__ = [] # type: ignore[var-annotated]
@@ -159,6 +160,7 @@ NATIVE_REDUCE_FNS = list(NATIVE_REDUCE_MAP.keys())
TORCH_REDUCE_FNS = list(TORCH_REDUCE_MAP.keys())
TENSOR_REDUCE_FNS = list(TENSOR_REDUCE_MAP.keys())
def _is_reduction(fn):
return fn in NATIVE_REDUCE_MAP or fn in TORCH_REDUCE_MAP or fn in TENSOR_REDUCE_MAP
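
A hedged example of a reduction routed through these tables, where masked-out elements are simply skipped:

import torch
from torch.masked import masked_tensor

mt = masked_tensor(torch.tensor([1.0, 2.0, 4.0]), torch.tensor([True, False, True]))
print(mt.sum())  # 5., the masked middle element is ignored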

File: torch/masked/maskedtensor/unary.py

@@ -4,6 +4,7 @@ import torch
from .core import _map_mt_args_kwargs, _wrap_result
__all__ = [] # type: ignore[var-annotated]
@@ -108,18 +109,18 @@ UNARY_NAMES_UNSUPPORTED = [
def _unary_helper(fn, args, kwargs, inplace):
if len(kwargs) != 0:
raise ValueError("MaskedTensor unary ops require that len(kwargs) == 0. "
"If you need support for this, please open an issue on Github.")
raise ValueError(
"MaskedTensor unary ops require that len(kwargs) == 0. "
"If you need support for this, please open an issue on Github."
)
for a in args[1:]:
if torch.is_tensor(a):
raise TypeError("MaskedTensor unary ops do not support additional Tensor arguments")
raise TypeError(
"MaskedTensor unary ops do not support additional Tensor arguments"
)
mask_args, mask_kwargs = _map_mt_args_kwargs(
args, kwargs, lambda x: x._masked_mask
)
data_args, data_kwargs = _map_mt_args_kwargs(
args, kwargs, lambda x: x._masked_data
)
mask_args, mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_mask)
data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_data)
if args[0].layout == torch.sparse_coo:
data_args[0] = data_args[0].coalesce()
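
A hedged example of a unary op going through this helper: the data is transformed while the mask passes through unchanged:

import torch
from torch.masked import masked_tensor

mt = masked_tensor(torch.tensor([-1.0, 2.0, -3.0]), torch.tensor([True, False, True]))
print(torch.abs(mt))  # values [1., --, 3.] under the same mask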