mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] enable UFMT for torch/masked/
(#127715)
Part of #123062 - #123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127715 Approved by: https://github.com/cpuhrsch
This commit is contained in:
committed by
PyTorch MergeBot
parent
406532f864
commit
01fc22056a
@ -1674,18 +1674,6 @@ exclude_patterns = [
|
||||
'torch/hub.py',
|
||||
'torch/library.py',
|
||||
'torch/linalg/__init__.py',
|
||||
# UFMT causes import cycle on masked
|
||||
'torch/masked/__init__.py',
|
||||
'torch/masked/_docs.py',
|
||||
'torch/masked/_ops.py',
|
||||
'torch/masked/maskedtensor/__init__.py',
|
||||
'torch/masked/maskedtensor/_ops_refs.py',
|
||||
'torch/masked/maskedtensor/binary.py',
|
||||
'torch/masked/maskedtensor/core.py',
|
||||
'torch/masked/maskedtensor/creation.py',
|
||||
'torch/masked/maskedtensor/passthrough.py',
|
||||
'torch/masked/maskedtensor/reductions.py',
|
||||
'torch/masked/maskedtensor/unary.py',
|
||||
'torch/monitor/__init__.py',
|
||||
'torch/nested/__init__.py',
|
||||
'torch/nn/__init__.py',
|
||||
|
@ -1,33 +1,34 @@
|
||||
from .maskedtensor.core import is_masked_tensor, MaskedTensor
|
||||
from .maskedtensor.creation import as_masked_tensor, masked_tensor
|
||||
from ._ops import (
|
||||
from torch.masked._ops import (
|
||||
_canonical_dim,
|
||||
_combine_input_and_mask,
|
||||
_generate_docstring,
|
||||
_reduction_identity,
|
||||
_where,
|
||||
_input_mask,
|
||||
_output_mask,
|
||||
_combine_input_and_mask,
|
||||
sum,
|
||||
prod,
|
||||
cumsum,
|
||||
cumprod,
|
||||
_reduction_identity,
|
||||
_where,
|
||||
amax,
|
||||
amin,
|
||||
argmax,
|
||||
argmin,
|
||||
cumprod,
|
||||
cumsum,
|
||||
log_softmax,
|
||||
logaddexp,
|
||||
logsumexp,
|
||||
mean,
|
||||
median,
|
||||
logsumexp,
|
||||
logaddexp,
|
||||
norm,
|
||||
var,
|
||||
std,
|
||||
softmax,
|
||||
log_softmax,
|
||||
softmin,
|
||||
normalize,
|
||||
prod,
|
||||
softmax,
|
||||
softmin,
|
||||
std,
|
||||
sum,
|
||||
var,
|
||||
)
|
||||
from torch.masked.maskedtensor.core import is_masked_tensor, MaskedTensor
|
||||
from torch.masked.maskedtensor.creation import as_masked_tensor, masked_tensor
|
||||
|
||||
|
||||
__all__ = [
|
||||
"as_masked_tensor",
|
||||
|
@ -1,15 +1,12 @@
|
||||
|
||||
import warnings
|
||||
|
||||
# A workaround to support both TorchScript and MyPy:
|
||||
from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.masked import as_masked_tensor, is_masked_tensor, MaskedTensor
|
||||
from . import _docs
|
||||
from torch import sym_float, Tensor
|
||||
from torch._prims_common import corresponding_real_dtype
|
||||
from torch import sym_float
|
||||
from torch.masked import _docs
|
||||
from torch.masked.maskedtensor.core import is_masked_tensor, MaskedTensor
|
||||
from torch.masked.maskedtensor.creation import as_masked_tensor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.types import _dtype as DType
|
||||
@ -469,7 +466,7 @@ def _canonical_dim(dim: DimOrDims, ndim: int) -> Tuple[int, ...]:
|
||||
raise RuntimeError(f"dim={d} appears multiple times in the list of dims")
|
||||
if d >= ndim or d < -ndim:
|
||||
raise IndexError(
|
||||
f"Dimension out of range (expected to be in range of [{-ndim}, {ndim-1}], but got {d})"
|
||||
f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {d})"
|
||||
)
|
||||
dims.append(d % ndim)
|
||||
return tuple(sorted(dims))
|
||||
@ -1420,7 +1417,6 @@ def median(
|
||||
dtype: Optional[DType] = None,
|
||||
mask: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
|
||||
"""\
|
||||
{reduction_signature}
|
||||
{reduction_descr}
|
||||
@ -1487,46 +1483,45 @@ def logaddexp(
|
||||
) -> Tensor:
|
||||
"""logaddexp(input, other, *, dtype=None, input_mask=None, other_mask=None) -> Tensor
|
||||
|
||||
Returns logaddexp of all the elements in the :attr:`input` and the :attr:`other`
|
||||
tensor. The :attr:`input` elements are masked out according to the boolean tensor
|
||||
:attr:`input_mask` and the attr:`other` elements are masked out according to the boolean tensor
|
||||
:attr:`other_mask`.
|
||||
Returns logaddexp of all the elements in the :attr:`input` and the :attr:`other`
|
||||
tensor. The :attr:`input` elements are masked out according to the boolean tensor
|
||||
:attr:`input_mask` and the attr:`other` elements are masked out according to the boolean tensor
|
||||
:attr:`other_mask`.
|
||||
|
||||
The shapes of a mask tensor and the tensor to be masked
|
||||
don't need to match, but they must be :ref:`broadcastable
|
||||
<broadcasting-semantics>` and the dimensionality of the mask
|
||||
tensor must not be greater than of the tensor to be masked.
|
||||
The shapes of a mask tensor and the tensor to be masked
|
||||
don't need to match, but they must be :ref:`broadcastable
|
||||
<broadcasting-semantics>` and the dimensionality of the mask
|
||||
tensor must not be greater than of the tensor to be masked.
|
||||
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
other (Tensor): the second input tensor
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
other (Tensor): the second input tensor
|
||||
|
||||
Keyword args:
|
||||
dtype (:class:`torch.dtype`, optional): the desired data type
|
||||
of returned tensor. If specified, the output tensor is
|
||||
casted to :attr:`dtype` after the operation is
|
||||
performed. Default: None.
|
||||
input_mask (:class:`torch.Tensor`, optional): the boolean tensor
|
||||
containing the binary mask of validity of :attr:`input` tensor elements.
|
||||
Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
|
||||
other_mask (:class:`torch.Tensor`, optional): the boolean tensor
|
||||
containing the binary mask of validity of :attr:`other` tensor elements.
|
||||
Default: None that is equivalent to ``torch.ones(other.shape, dtype=torch.bool)``.
|
||||
Keyword args:
|
||||
dtype (:class:`torch.dtype`, optional): the desired data type
|
||||
of returned tensor. If specified, the output tensor is
|
||||
casted to :attr:`dtype` after the operation is
|
||||
performed. Default: None.
|
||||
input_mask (:class:`torch.Tensor`, optional): the boolean tensor
|
||||
containing the binary mask of validity of :attr:`input` tensor elements.
|
||||
Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
|
||||
other_mask (:class:`torch.Tensor`, optional): the boolean tensor
|
||||
containing the binary mask of validity of :attr:`other` tensor elements.
|
||||
Default: None that is equivalent to ``torch.ones(other.shape, dtype=torch.bool)``.
|
||||
|
||||
Example::
|
||||
Example::
|
||||
|
||||
>>> input = torch.tensor([-100.0, -200, -300])
|
||||
>>> input
|
||||
tensor([-100., -200., -300.])
|
||||
>>> other = torch.tensor([-1.0, -2, -3])
|
||||
>>> other
|
||||
tensor([-1., -2., -3.])
|
||||
>>> mask = torch.tensor([True, False, True])
|
||||
>>> mask
|
||||
tensor([ True, False, True])
|
||||
>>> torch.masked._ops.logaddexp(input, other, input_mask=mask, other_mask=mask)
|
||||
tensor([-1., -inf, -3.])
|
||||
"""
|
||||
>>> input = torch.tensor([-100.0, -200, -300])
|
||||
>>> input
|
||||
tensor([-100., -200., -300.])
|
||||
>>> other = torch.tensor([-1.0, -2, -3])
|
||||
>>> other
|
||||
tensor([-1., -2., -3.])
|
||||
>>> mask = torch.tensor([True, False, True])
|
||||
>>> mask
|
||||
tensor([ True, False, True])
|
||||
>>> torch.masked._ops.logaddexp(input, other, input_mask=mask, other_mask=mask)
|
||||
tensor([-1., -inf, -3.])"""
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if input.layout == torch.strided and other.layout == torch.strided:
|
||||
@ -1586,7 +1581,9 @@ def _std_var(
|
||||
mask: Optional[Tensor],
|
||||
take_sqrt: Optional[bool],
|
||||
) -> Tensor:
|
||||
assert (unbiased is None or correction_opt is None), "Only one of unbiased and correction may be given"
|
||||
assert (
|
||||
unbiased is None or correction_opt is None
|
||||
), "Only one of unbiased and correction may be given"
|
||||
correction = 1.0
|
||||
if unbiased is not None:
|
||||
correction = 1.0 if unbiased else 0.0
|
||||
@ -1632,8 +1629,11 @@ def _std_var(
|
||||
if not keepdim:
|
||||
count = count.reshape(total.shape)
|
||||
if correction != 0:
|
||||
real_dtype = (corresponding_real_dtype(compute_dtype)
|
||||
if compute_dtype.is_complex else compute_dtype)
|
||||
real_dtype = (
|
||||
corresponding_real_dtype(compute_dtype)
|
||||
if compute_dtype.is_complex
|
||||
else compute_dtype
|
||||
)
|
||||
count = count.to(real_dtype)
|
||||
count = torch.subtract(count, correction)
|
||||
count = torch.maximum(count, count.new_zeros([]))
|
||||
|
@ -1,43 +1,45 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
|
||||
from functools import partial
|
||||
from typing import Callable, Any, Dict, TYPE_CHECKING
|
||||
from typing import Any, Callable, Dict, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch._ops
|
||||
|
||||
from .binary import (
|
||||
_apply_native_binary,
|
||||
NATIVE_BINARY_FNS,
|
||||
NATIVE_INPLACE_BINARY_FNS,
|
||||
)
|
||||
from .core import is_masked_tensor, MaskedTensor, _get_data, _masks_match, _maybe_get_mask
|
||||
from .passthrough import (
|
||||
_apply_pass_through_fn,
|
||||
PASSTHROUGH_FNS
|
||||
from .binary import _apply_native_binary, NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS
|
||||
from .core import (
|
||||
_get_data,
|
||||
_masks_match,
|
||||
_maybe_get_mask,
|
||||
is_masked_tensor,
|
||||
MaskedTensor,
|
||||
)
|
||||
from .passthrough import _apply_pass_through_fn, PASSTHROUGH_FNS
|
||||
from .reductions import (
|
||||
_apply_reduction,
|
||||
NATIVE_REDUCE_FNS,
|
||||
TORCH_REDUCE_FNS,
|
||||
TENSOR_REDUCE_FNS,
|
||||
TORCH_REDUCE_FNS,
|
||||
)
|
||||
from .unary import (
|
||||
_apply_native_unary,
|
||||
NATIVE_UNARY_FNS,
|
||||
NATIVE_INPLACE_UNARY_FNS,
|
||||
)
|
||||
from .unary import _apply_native_unary, NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS
|
||||
|
||||
|
||||
__all__ = [] # type: ignore[var-annotated]
|
||||
|
||||
|
||||
def _check_args_kwargs_length(args, kwargs, error_prefix, len_args=None, len_kwargs=None):
|
||||
def _check_args_kwargs_length(
|
||||
args, kwargs, error_prefix, len_args=None, len_kwargs=None
|
||||
):
|
||||
if len_args is not None and len_args != len(args):
|
||||
raise ValueError(f"{error_prefix}: len(args) must be {len_args} but got {len(args)}")
|
||||
raise ValueError(
|
||||
f"{error_prefix}: len(args) must be {len_args} but got {len(args)}"
|
||||
)
|
||||
if len_kwargs is not None and len_kwargs != len(kwargs):
|
||||
raise ValueError(f"{error_prefix}: len(kwargs) must be {len_kwargs} but got {len(kwargs)}")
|
||||
raise ValueError(
|
||||
f"{error_prefix}: len(kwargs) must be {len_kwargs} but got {len(kwargs)}"
|
||||
)
|
||||
|
||||
|
||||
class _MaskedContiguous(torch.autograd.Function):
|
||||
@ -116,7 +118,9 @@ class _MaskedToSparseCsr(torch.autograd.Function):
|
||||
raise ValueError("MaskedToSparseCsr forward: input must be a MaskedTensor.")
|
||||
|
||||
if input._masked_data.ndim != 2:
|
||||
raise ValueError(f"Only 2D tensors can be converted to the SparseCsr layout but got shape: {input._masked_data.size()}")
|
||||
raise ValueError(
|
||||
f"Only 2D tensors can be converted to the SparseCsr layout but got shape: {input._masked_data.size()}"
|
||||
)
|
||||
|
||||
if input.layout == torch.sparse_csr:
|
||||
return input
|
||||
@ -157,7 +161,11 @@ class _MaskedWhere(torch.autograd.Function):
|
||||
_MASKEDTENSOR_FUNCTION_TABLE = {}
|
||||
|
||||
_function_fn_apply_map = {
|
||||
(tuple(NATIVE_REDUCE_FNS), tuple(TORCH_REDUCE_FNS), tuple(TENSOR_REDUCE_FNS)): _apply_reduction,
|
||||
(
|
||||
tuple(NATIVE_REDUCE_FNS),
|
||||
tuple(TORCH_REDUCE_FNS),
|
||||
tuple(TENSOR_REDUCE_FNS),
|
||||
): _apply_reduction,
|
||||
}
|
||||
|
||||
for fn_map_list, apply_fn in _function_fn_apply_map.items():
|
||||
@ -177,9 +185,11 @@ def register_function_func(ops):
|
||||
def foo(func, *args, **kwargs):
|
||||
<implementation>
|
||||
"""
|
||||
|
||||
def wrapper(func):
|
||||
for op in ops:
|
||||
_MASKEDTENSOR_FUNCTION_TABLE[op] = partial(func, op)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@ -190,7 +200,9 @@ def _general_function_reductions(func, *args, **kwargs):
|
||||
|
||||
@register_function_func([torch.Tensor.where, torch.where])
|
||||
def _function_where(func, *args, **kwargs):
|
||||
_check_args_kwargs_length(args, kwargs, "__torch_function__, torch.where", len_args=3, len_kwargs=0)
|
||||
_check_args_kwargs_length(
|
||||
args, kwargs, "__torch_function__, torch.where", len_args=3, len_kwargs=0
|
||||
)
|
||||
return _MaskedWhere.apply(*args)
|
||||
|
||||
|
||||
@ -216,6 +228,7 @@ def _function_to_sparse_csr(func, *args, **kwargs):
|
||||
|
||||
_MASKEDTENSOR_DISPATCH_TABLE: Dict["torch._ops.OpOverload", Callable[..., Any]] = {}
|
||||
|
||||
|
||||
def register_dispatch_func(aten_ops):
|
||||
"""
|
||||
Used for registering a new __torch_dispatch__ function to MaskedTensor
|
||||
@ -227,9 +240,11 @@ def register_dispatch_func(aten_ops):
|
||||
def foo(func, *args, **kwargs):
|
||||
<implementation>
|
||||
"""
|
||||
|
||||
def wrapper(func):
|
||||
for aten_op in aten_ops:
|
||||
_MASKEDTENSOR_DISPATCH_TABLE[aten_op] = partial(func, aten_op)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@ -272,9 +287,7 @@ def layout(func, *args, **kwargs):
|
||||
def is_contiguous(func, *args, **kwargs):
|
||||
data = _get_data(args[0])
|
||||
if data.is_sparse:
|
||||
raise ValueError(
|
||||
"MaskedTensors with sparse data do not have is_contiguous"
|
||||
)
|
||||
raise ValueError("MaskedTensors with sparse data do not have is_contiguous")
|
||||
return func(data, *args[1:], **kwargs)
|
||||
|
||||
|
||||
@ -301,9 +314,7 @@ def is_non_overlapping_and_dense(func, *args, **kwargs):
|
||||
@register_dispatch_func([torch.ops.aten.contiguous])
|
||||
def contiguous(func, *args, **kwargs):
|
||||
if _get_data(args[0]).is_sparse:
|
||||
raise ValueError(
|
||||
"MaskedTensors with sparse data do not have contiguous"
|
||||
)
|
||||
raise ValueError("MaskedTensors with sparse data do not have contiguous")
|
||||
return _MaskedContiguous.apply(args[0])
|
||||
|
||||
|
||||
@ -313,9 +324,13 @@ def new_empty_strided(func, *args, **kwargs):
|
||||
data = _get_data(args[0])
|
||||
mask = _maybe_get_mask(args[0])
|
||||
if tuple(args[1]) != tuple(data.size()):
|
||||
raise ValueError(f"__torch_dispatch__, {func}: args[1] expected to be the same as data.size()")
|
||||
raise ValueError(
|
||||
f"__torch_dispatch__, {func}: args[1] expected to be the same as data.size()"
|
||||
)
|
||||
if tuple(args[2]) != tuple(data.stride()):
|
||||
raise ValueError(f"__torch_dispatch__, {func}: args[2] expected to be the same as data.stride()")
|
||||
raise ValueError(
|
||||
f"__torch_dispatch__, {func}: args[2] expected to be the same as data.stride()"
|
||||
)
|
||||
return MaskedTensor(func(data, args[1], args[2], **kwargs), mask)
|
||||
|
||||
|
||||
@ -339,7 +354,9 @@ def _to_copy(func, *args, **kwargs):
|
||||
|
||||
@register_dispatch_func([torch.ops.aten._softmax])
|
||||
def _softmax(func, *args, **kwargs):
|
||||
_check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0)
|
||||
_check_args_kwargs_length(
|
||||
args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0
|
||||
)
|
||||
data = _get_data(args[0])
|
||||
mask = _maybe_get_mask(args[0])
|
||||
result_data = torch.ops.aten._masked_softmax(data, ~mask, args[1], 2)
|
||||
@ -359,7 +376,9 @@ def _softmax_backward_data(func, *args, **kwargs):
|
||||
grad, output, dim, input_dtype = args
|
||||
if is_masked_tensor(grad) and is_masked_tensor(output):
|
||||
if not _masks_match(grad, output):
|
||||
raise ValueError("__torch_dispatch__, {func}: expected the masks of grad and output to match")
|
||||
raise ValueError(
|
||||
"__torch_dispatch__, {func}: expected the masks of grad and output to match"
|
||||
)
|
||||
grad_data = _get_data(grad)
|
||||
new_grad_data = torch.ops.aten._masked_softmax_backward(
|
||||
grad_data,
|
||||
@ -370,7 +389,9 @@ def _softmax_backward_data(func, *args, **kwargs):
|
||||
res = MaskedTensor(new_grad_data, _maybe_get_mask(grad))
|
||||
return res
|
||||
else:
|
||||
raise ValueError(f"__torch_dispatch__, {func}: grad and output must both be MaskedTensors")
|
||||
raise ValueError(
|
||||
f"__torch_dispatch__, {func}: grad and output must both be MaskedTensors"
|
||||
)
|
||||
|
||||
|
||||
@register_dispatch_func([torch.ops.aten.copy_])
|
||||
@ -384,7 +405,9 @@ def copy_(func, *args, **kwargs):
|
||||
|
||||
@register_dispatch_func([torch.ops.aten.where])
|
||||
def where(func, *args, **kwargs):
|
||||
_check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0)
|
||||
_check_args_kwargs_length(
|
||||
args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0
|
||||
)
|
||||
if not torch.is_tensor(args[0]):
|
||||
raise ValueError("__torch_dispatch__, {func}: expected args[0] to be a tensor")
|
||||
mx = args[1]
|
||||
@ -400,7 +423,9 @@ def where(func, *args, **kwargs):
|
||||
|
||||
@register_dispatch_func([torch.ops.aten._to_sparse])
|
||||
def _to_sparse(func, *args, **kwargs):
|
||||
_check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0)
|
||||
_check_args_kwargs_length(
|
||||
args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
|
||||
)
|
||||
if not torch.is_tensor(args[0]):
|
||||
raise TypeError("__torch_dispatch__, {func}: expected args[0] to be a tensor")
|
||||
mt = args[0]
|
||||
@ -415,7 +440,9 @@ def _to_sparse(func, *args, **kwargs):
|
||||
|
||||
@register_dispatch_func([torch.ops.aten._to_sparse_csr])
|
||||
def _to_sparse_csr(func, *args, **kwargs):
|
||||
_check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0)
|
||||
_check_args_kwargs_length(
|
||||
args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
|
||||
)
|
||||
if not torch.is_tensor(args[0]):
|
||||
raise ValueError("__torch_dispatch__, {func}: expected args[0] to be a tensor")
|
||||
mt = args[0]
|
||||
@ -430,7 +457,9 @@ def _to_sparse_csr(func, *args, **kwargs):
|
||||
|
||||
@register_dispatch_func([torch.ops.aten._to_dense])
|
||||
def _to_dense(func, *args, **kwargs):
|
||||
_check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0)
|
||||
_check_args_kwargs_length(
|
||||
args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
|
||||
)
|
||||
if not torch.is_tensor(args[0]):
|
||||
raise ValueError("__torch_dispatch__, {func}: expected args[0] to be a tensor")
|
||||
mt = args[0]
|
||||
@ -444,14 +473,18 @@ def _to_dense(func, *args, **kwargs):
|
||||
@register_dispatch_func([torch.ops.aten._indices])
|
||||
def _indices(func, *args, **kwargs):
|
||||
# Assumes data is sparse
|
||||
_check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0)
|
||||
_check_args_kwargs_length(
|
||||
args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
|
||||
)
|
||||
data = _get_data(args[0]).indices()
|
||||
return MaskedTensor(data, torch.ones_like(data).bool())
|
||||
|
||||
|
||||
@register_dispatch_func([torch.ops.aten._values])
|
||||
def _values(func, *args, **kwargs):
|
||||
_check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0)
|
||||
_check_args_kwargs_length(
|
||||
args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
|
||||
)
|
||||
data = _get_data(args[0]).values()
|
||||
return MaskedTensor(data, torch.ones_like(data).bool())
|
||||
|
||||
|
@ -2,7 +2,14 @@
|
||||
|
||||
import torch
|
||||
|
||||
from .core import _map_mt_args_kwargs, _masks_match, _tensors_match, _wrap_result, is_masked_tensor
|
||||
from .core import (
|
||||
_map_mt_args_kwargs,
|
||||
_masks_match,
|
||||
_tensors_match,
|
||||
_wrap_result,
|
||||
is_masked_tensor,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [] # type: ignore[var-annotated]
|
||||
|
||||
@ -79,25 +86,22 @@ def _binary_helper(fn, args, kwargs, inplace):
|
||||
raise ValueError("len(kwargs) must equal 0")
|
||||
for a in args[2:]:
|
||||
if torch.is_tensor(a):
|
||||
raise TypeError("MaskedTensor binary ops do not support Tensor arguments aside from the lhs and rhs")
|
||||
raise TypeError(
|
||||
"MaskedTensor binary ops do not support Tensor arguments aside from the lhs and rhs"
|
||||
)
|
||||
|
||||
if not _masks_match(*args[:2]):
|
||||
raise ValueError(
|
||||
"Input masks must match. If you need support for this, please open an issue on Github."
|
||||
)
|
||||
|
||||
data_args, data_kwargs = _map_mt_args_kwargs(
|
||||
args, kwargs, lambda x: x.get_data()
|
||||
)
|
||||
mask_args, mask_kwargs = _map_mt_args_kwargs(
|
||||
args, kwargs, lambda x: x.get_mask()
|
||||
)
|
||||
data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data())
|
||||
mask_args, mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask())
|
||||
|
||||
args0_layout = data_args[0].layout
|
||||
same_layout = (
|
||||
(torch.is_tensor(data_args[1]) or is_masked_tensor(data_args[1])) and
|
||||
(args0_layout == data_args[1].layout)
|
||||
)
|
||||
torch.is_tensor(data_args[1]) or is_masked_tensor(data_args[1])
|
||||
) and (args0_layout == data_args[1].layout)
|
||||
|
||||
if args0_layout == torch.sparse_coo:
|
||||
if same_layout:
|
||||
@ -106,7 +110,9 @@ def _binary_helper(fn, args, kwargs, inplace):
|
||||
"sparse_coo indices must match. If you need support for this, please open an issue on Github."
|
||||
)
|
||||
if data_args[0].size() != data_args[1].size():
|
||||
raise ValueError("input1 and input2 must have the same size for binary functions.")
|
||||
raise ValueError(
|
||||
"input1 and input2 must have the same size for binary functions."
|
||||
)
|
||||
|
||||
data_args[1] = data_args[1].values()
|
||||
|
||||
|
@ -13,7 +13,7 @@ __all__ = [
|
||||
|
||||
|
||||
def is_masked_tensor(a):
|
||||
r""" Returns True if the input is a MaskedTensor, else False
|
||||
r"""Returns True if the input is a MaskedTensor, else False
|
||||
|
||||
Args:
|
||||
a: any input
|
||||
@ -35,7 +35,9 @@ def _tensors_match(a, b, exact=True, rtol=1e-05, atol=1e-08):
|
||||
if is_masked_tensor(a) or is_masked_tensor(b):
|
||||
raise ValueError("Neither `a` nor `b` can be a MaskedTensor.")
|
||||
if a.layout != b.layout:
|
||||
raise ValueError(f"`a` and `b` must have the same layout. Got {a.layout} and {b.layout}")
|
||||
raise ValueError(
|
||||
f"`a` and `b` must have the same layout. Got {a.layout} and {b.layout}"
|
||||
)
|
||||
|
||||
if a.dtype != b.dtype:
|
||||
b = b.type(a.dtype)
|
||||
@ -108,9 +110,7 @@ def _masked_tensor_str(data, mask, formatter):
|
||||
formatter.format(d.item()) if isinstance(d.item(), float) else str(d.item())
|
||||
for d in data
|
||||
]
|
||||
max_len = max(
|
||||
8 if x[1] else len(x[0]) for x in zip(formatted_elements, ~mask)
|
||||
)
|
||||
max_len = max(8 if x[1] else len(x[0]) for x in zip(formatted_elements, ~mask))
|
||||
return (
|
||||
"["
|
||||
+ ", ".join(
|
||||
@ -153,13 +153,21 @@ class MaskedTensor(torch.Tensor):
|
||||
kwargs["requires_grad"] = requires_grad
|
||||
kwargs["dispatch_sizes_strides_policy"] = "strides"
|
||||
kwargs["dispatch_layout"] = True
|
||||
warnings.warn(("The PyTorch API of MaskedTensors is in prototype stage "
|
||||
"and will change in the near future. Please open a Github issue "
|
||||
"for features requests and see our documentation on the torch.masked "
|
||||
"module for further information about the project."), UserWarning)
|
||||
warnings.warn(
|
||||
(
|
||||
"The PyTorch API of MaskedTensors is in prototype stage "
|
||||
"and will change in the near future. Please open a Github issue "
|
||||
"for features requests and see our documentation on the torch.masked "
|
||||
"module for further information about the project."
|
||||
),
|
||||
UserWarning,
|
||||
)
|
||||
if data.requires_grad:
|
||||
warnings.warn("It is not recommended to create a MaskedTensor with a tensor that requires_grad. "
|
||||
"To avoid this, you can use data.clone().detach()", UserWarning)
|
||||
warnings.warn(
|
||||
"It is not recommended to create a MaskedTensor with a tensor that requires_grad. "
|
||||
"To avoid this, you can use data.clone().detach()",
|
||||
UserWarning,
|
||||
)
|
||||
return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs) # type: ignore[attr-defined]
|
||||
|
||||
def _preprocess_data(self, data, mask):
|
||||
@ -184,17 +192,23 @@ class MaskedTensor(torch.Tensor):
|
||||
data = self._masked_data
|
||||
mask = self.get_mask()
|
||||
if type(data) != type(mask):
|
||||
raise TypeError(f"data and mask must have the same type. Got {type(data)} and {type(mask)}")
|
||||
raise TypeError(
|
||||
f"data and mask must have the same type. Got {type(data)} and {type(mask)}"
|
||||
)
|
||||
if data.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}:
|
||||
raise TypeError(f"data layout of {data.layout} is not supported.")
|
||||
if data.layout == torch.sparse_coo:
|
||||
if not _tensors_match(data.indices(), mask.indices(), exact=True):
|
||||
raise ValueError("data and mask are both sparse COO tensors but do not have the same indices.")
|
||||
raise ValueError(
|
||||
"data and mask are both sparse COO tensors but do not have the same indices."
|
||||
)
|
||||
elif data.layout == torch.sparse_csr:
|
||||
if not _tensors_match(
|
||||
data.crow_indices(), mask.crow_indices(), exact=True
|
||||
) or not _tensors_match(data.col_indices(), mask.col_indices(), exact=True):
|
||||
raise ValueError("data and mask are both sparse CSR tensors but do not share either crow or col indices.")
|
||||
raise ValueError(
|
||||
"data and mask are both sparse CSR tensors but do not share either crow or col indices."
|
||||
)
|
||||
if mask.dtype != torch.bool:
|
||||
raise TypeError("mask must have dtype bool.")
|
||||
if not (
|
||||
@ -219,7 +233,8 @@ class MaskedTensor(torch.Tensor):
|
||||
|
||||
@staticmethod
|
||||
def _from_values(data, mask):
|
||||
""" Differentiable constructor for MaskedTensor """
|
||||
"""Differentiable constructor for MaskedTensor"""
|
||||
|
||||
class Constructor(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, data, mask):
|
||||
@ -265,6 +280,7 @@ class MaskedTensor(torch.Tensor):
|
||||
kwargs = kwargs or {}
|
||||
|
||||
from ._ops_refs import _MASKEDTENSOR_FUNCTION_TABLE
|
||||
|
||||
if func in _MASKEDTENSOR_FUNCTION_TABLE:
|
||||
return _MASKEDTENSOR_FUNCTION_TABLE[func](*args, **kwargs)
|
||||
|
||||
@ -286,6 +302,7 @@ class MaskedTensor(torch.Tensor):
|
||||
func = func.overloadpacket
|
||||
|
||||
from ._ops_refs import _MASKEDTENSOR_DISPATCH_TABLE
|
||||
|
||||
if func in _MASKEDTENSOR_DISPATCH_TABLE:
|
||||
return _MASKEDTENSOR_DISPATCH_TABLE[func](*args, **kwargs)
|
||||
|
||||
|
@ -2,20 +2,21 @@
|
||||
|
||||
from .core import MaskedTensor
|
||||
|
||||
|
||||
__all__ = [
|
||||
"as_masked_tensor",
|
||||
"masked_tensor",
|
||||
]
|
||||
|
||||
|
||||
""""
|
||||
These two factory functions are intended to mirror
|
||||
torch.tensor - guaranteed to be a leaf node
|
||||
torch.as_tensor - differentiable constructor that preserves the autograd history
|
||||
"""
|
||||
# These two factory functions are intended to mirror
|
||||
# torch.tensor - guaranteed to be a leaf node
|
||||
# torch.as_tensor - differentiable constructor that preserves the autograd history
|
||||
|
||||
|
||||
def masked_tensor(data, mask, requires_grad=False):
|
||||
return MaskedTensor(data, mask, requires_grad)
|
||||
|
||||
|
||||
def as_masked_tensor(data, mask):
|
||||
return MaskedTensor._from_values(data, mask)
|
||||
|
@ -10,6 +10,7 @@ import torch
|
||||
|
||||
from .core import _map_mt_args_kwargs, _wrap_result
|
||||
|
||||
|
||||
__all__ = [] # type: ignore[var-annotated]
|
||||
|
||||
|
||||
|
@ -7,6 +7,7 @@ import torch
|
||||
from .core import is_masked_tensor
|
||||
from .creation import as_masked_tensor, masked_tensor
|
||||
|
||||
|
||||
__all__ = [] # type: ignore[var-annotated]
|
||||
|
||||
|
||||
@ -159,6 +160,7 @@ NATIVE_REDUCE_FNS = list(NATIVE_REDUCE_MAP.keys())
|
||||
TORCH_REDUCE_FNS = list(TORCH_REDUCE_MAP.keys())
|
||||
TENSOR_REDUCE_FNS = list(TENSOR_REDUCE_MAP.keys())
|
||||
|
||||
|
||||
def _is_reduction(fn):
|
||||
return fn in NATIVE_REDUCE_MAP or fn in TORCH_REDUCE_MAP or fn in TENSOR_REDUCE_MAP
|
||||
|
||||
|
@ -4,6 +4,7 @@ import torch
|
||||
|
||||
from .core import _map_mt_args_kwargs, _wrap_result
|
||||
|
||||
|
||||
__all__ = [] # type: ignore[var-annotated]
|
||||
|
||||
|
||||
@ -108,18 +109,18 @@ UNARY_NAMES_UNSUPPORTED = [
|
||||
|
||||
def _unary_helper(fn, args, kwargs, inplace):
|
||||
if len(kwargs) != 0:
|
||||
raise ValueError("MaskedTensor unary ops require that len(kwargs) == 0. "
|
||||
"If you need support for this, please open an issue on Github.")
|
||||
raise ValueError(
|
||||
"MaskedTensor unary ops require that len(kwargs) == 0. "
|
||||
"If you need support for this, please open an issue on Github."
|
||||
)
|
||||
for a in args[1:]:
|
||||
if torch.is_tensor(a):
|
||||
raise TypeError("MaskedTensor unary ops do not support additional Tensor arguments")
|
||||
raise TypeError(
|
||||
"MaskedTensor unary ops do not support additional Tensor arguments"
|
||||
)
|
||||
|
||||
mask_args, mask_kwargs = _map_mt_args_kwargs(
|
||||
args, kwargs, lambda x: x._masked_mask
|
||||
)
|
||||
data_args, data_kwargs = _map_mt_args_kwargs(
|
||||
args, kwargs, lambda x: x._masked_data
|
||||
)
|
||||
mask_args, mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_mask)
|
||||
data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_data)
|
||||
|
||||
if args[0].layout == torch.sparse_coo:
|
||||
data_args[0] = data_args[0].coalesce()
|
||||
|
Reference in New Issue
Block a user