Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Strided masked var. (#68738)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68738
Test Plan: Imported from OSS
Reviewed By: davidberard98
Differential Revision: D32767155
Pulled By: cpuhrsch
fbshipit-source-id: a5c095103405fbfc28b9f4fd624bdbbc45e7f715

committed by Facebook GitHub Bot
parent 291e56eda4
commit 370d0afc1b
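For orientation, a minimal usage sketch of the strided masked variance this commit adds. The call follows the `var` signature introduced in the diff below (positional `dim` and `unbiased`, keyword-only `mask`); the tensor values are illustrative only:

import torch

x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])
mask = torch.tensor([[True, True, False],
                     [False, False, False]])

# Masked variance along dim=1 without Bessel's correction.
print(torch._masked.var(x, 1, False, mask=mask))
# row 0: variance of [1., 2.] -> 0.25; row 1 is fully masked out -> nan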
@@ -150,6 +150,7 @@ def apply_masked_normalization_along_dim(op, input, *args, **kwargs):
 
 reference_functions = dict(
     norm=lambda *args, **kwargs: apply_masked_reduction_along_dim(torch.linalg.vector_norm, *args, **dict(kwargs, dim_position=1)),
+    var=lambda *args, **kwargs: apply_masked_reduction_along_dim(torch.var, *args, **dict(kwargs, dim_position=0)),
     softmax=lambda *args, **kwargs: apply_masked_normalization_along_dim(torch.softmax, *args, **kwargs),
     log_softmax=lambda *args, **kwargs: apply_masked_normalization_along_dim(torch.log_softmax, *args, **kwargs),
     softmin=lambda *args, **kwargs: apply_masked_normalization_along_dim(torch.nn.functional.softmin, *args, **kwargs),
@@ -167,10 +168,14 @@ class TestMasked(TestCase):
     @suppress_warnings
     @ops(masked_ops_with_references)
     def test_reference_masked(self, device, dtype, op):
-        ref_op = reference_functions[op.name.rsplit('.', 1)[-1]]
+        op_name = op.name.rsplit('.', 1)[-1]
+        ref_op = reference_functions[op_name]
         sample_inputs = op.sample_inputs(device, dtype)
         for sample_input in sample_inputs:
             t_inp, t_args, t_kwargs = sample_input.input, sample_input.args, sample_input.kwargs
+            if op_name == 'var' and not (t_inp.dtype.is_floating_point or t_inp.dtype.is_complex):
+                # torch.var does not support integer inputs
+                continue
             actual = op.op(t_inp, *t_args, **t_kwargs)
             expected = ref_op(t_inp, *t_args, **t_kwargs)
             outmask = torch._masked._output_mask(op.op, t_inp, *t_args, **t_kwargs)
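The var entry added above routes the reference computation through apply_masked_reduction_along_dim with torch.var, i.e. a plain torch.var applied only to the unmasked elements of each reduced slice. A self-contained sketch of that idea (not the test helper itself; the values and the explicit Python loop are illustrative only):

import torch

x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])
mask = torch.tensor([[True, True, False],
                     [True, True, True]])

# Reference-style result: torch.var restricted to the unmasked elements of each row.
expected = torch.stack([torch.var(row[m], unbiased=False) for row, m in zip(x, mask)])
actual = torch._masked.var(x, 1, False, mask=mask)
print(expected, actual)  # both: tensor([0.2500, 0.6667])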
@@ -1,3 +1,5 @@
 # -*- coding: utf-8 -*-
+
 from typing import Optional, Tuple, List, Union, Any
+
 import torch
@@ -141,6 +143,7 @@ Example::
     amax=(('dim',), ('keepdim=False', 'dtype=None', 'mask=None')),
     mean=(('dim',), ('keepdim=False', 'dtype=None', 'mask=None')),
     norm=(('ord', 'dim',), ('keepdim=False', 'dtype=None', 'mask=None')),
+    var=(('dim', 'unbiased'), ('keepdim=False', 'dtype=None', 'mask=None')),
     softmax=(('dim__as_int',), ('dtype=None', 'mask=None')),
     log_softmax=(('dim__as_int',), ('dtype=None', 'mask=None')),
     softmin=(('dim__as_int',), ('dtype=None', 'mask=None')),
@@ -159,6 +162,9 @@ ord (int, float, optional): the order of vector norm. Default: 2.
     ord__required='''\
 ord (int, float): the order of vector norm. Default: 2.
 See :func:`torch.linalg.vector_norm` for a list of supported norms.''',
+    unbiased='''\
+unbiased (bool): when True, use Bessel's correction, otherwise, compute
+the uncorrected sample variance.''',
     eps='''\
 eps (float, optional): small value to avoid division by zero. Default: {default}.''',
     keepdim='''\
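The new ``unbiased`` template distinguishes Bessel-corrected from uncorrected variance. As a plain restatement (not part of the docstring templates), with n unmasked elements and mean \bar{x}:

    s^2_{\text{unbiased}} = \frac{1}{n-1}\sum_{i=1}^{n}(x_i - \bar{x})^2,
    \qquad
    s^2_{\text{uncorrected}} = \frac{1}{n}\sum_{i=1}^{n}(x_i - \bar{x})^2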
@@ -199,7 +205,8 @@ defined as ``x[i]/max(norm(x, p), eps)``.''')
     amax='maximum',
     amin='minimum',
     mean='mean',
-    norm='norm')
+    norm='norm',
+    var='variance')
 
 normalization_names = dict(
     softmax='softmax',
@@ -219,6 +226,8 @@ defined as ``x[i]/max(norm(x, p), eps)``.''')
     if func.__name__ in {'norm', 'normalize'}:
         example_args = (2.0, example_dim)
         example_input = example_input.to(dtype=torch.float32)
+    elif func.__name__ in {'var'}:
+        example_args = (example_dim, False)
     else:
         example_args = (example_dim,)
 
@@ -340,6 +349,8 @@ def _reduction_identity(op_name: str, input: Tensor, *args):
             assert torch.is_floating_point(input), input.dtype
             return torch.tensor(torch.inf, dtype=dtype, device=device)
         return torch.tensor(0, dtype=dtype, device=device)
+    elif op_name == 'var':
+        return None
     raise NotImplementedError(f'identity of {op_name} on {dtype} input')
 
 
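The value returned by _reduction_identity is what a masked reduction produces for a fully masked-out slice (e.g. 0 for sum); sample variance has no such neutral element, so the new branch returns None and the var implementation below yields nan for those positions instead. A tiny illustrative sketch of the contrast (input values are made up):

import torch

x = torch.tensor([1.0, 2.0, 3.0])
empty_mask = torch.zeros(3, dtype=torch.bool)  # every element masked out

print(torch._masked.sum(x, mask=empty_mask))            # tensor(0.): sum falls back to its identity
print(torch._masked.var(x, 0, False, mask=empty_mask))  # tensor(nan): var has no identity value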
@@ -383,7 +394,7 @@ def _output_mask(op, input: Tensor, *args, **kwargs) -> Tensor:
     """Return output mask of masked operation applied to given arguments.
     """
     if callable(op):
-        is_reduction = op.__name__ in {'sum', 'prod', 'amax', 'amin', 'mean', 'norm'}
+        is_reduction = op.__name__ in {'sum', 'prod', 'amax', 'amin', 'mean', 'norm', 'var'}
         is_normalization = op.__name__ in {'softmax', 'log_softmax', 'softmin', 'normalize'}
         if is_reduction:
             if op.__name__ == 'norm':
@@ -575,6 +586,53 @@ reduction, is ``{identity_float32}``, except for ``ord=-inf`` it is
         raise ValueError(f'masked norm expects strided tensor (got {input.layout} tensor)')
 
 
+@_apply_docstring_templates
+def var(input: Tensor,
+        dim: DimOrDims = None,
+        unbiased: Optional[bool] = False,
+        *,
+        keepdim: Optional[bool] = False,
+        dtype: Optional[DType] = None,
+        mask: Optional[Tensor] = None) -> Tensor:
+    """\
+{reduction_signature}
+
+{reduction_descr}
+
+The identity value of sample variance operation is undefined. The
+elements of output tensor with strided layout, that correspond to
+fully masked-out elements, have ``nan`` values.
+
+{reduction_args}
+
+{reduction_example}"""
+    if dtype is None:
+        dtype = input.dtype
+        if not (dtype.is_floating_point or dtype.is_complex):
+            dtype = torch.float32
+    compute_dtype = dtype
+    if not (compute_dtype.is_floating_point or compute_dtype.is_complex):
+        compute_dtype = torch.float32
+    if input.layout == torch.strided:
+        inmask = _input_mask(input, mask=mask)
+        count = sum(inmask.new_ones(input.shape, dtype=torch.int64), dim, keepdim=True, mask=inmask)
+        sample_total = sum(input, dim, keepdim=True, dtype=dtype, mask=inmask)
+        # TODO: replace torch.subtract/divide/square/maximum with
+        # masked subtract/divide/square/maximum when these will be
+        # available.
+        sample_mean = torch.divide(sample_total, count)
+        x = torch.subtract(input, sample_mean)
+        total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype, mask=inmask)
+        if not keepdim:
+            count = count.reshape(total.shape)
+        if unbiased:
+            count = torch.subtract(count, 1)
+            count = torch.maximum(count, count.new_zeros([]))
+        return torch.divide(total, count).to(dtype=dtype)
+    else:
+        raise ValueError(f'masked var expects strided tensor (got {input.layout} tensor)')
+
+
 @_apply_docstring_templates
 def softmax(input: Tensor,
             dim: int,
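To make the arithmetic above concrete, here is a small worked sketch that mirrors the steps of the new function (masked count, masked mean, masked sum of squared deviations, optional Bessel correction). The tensor and mask are made-up values, and the final comparison against torch._masked.var is only a sanity check, not part of the library code:

import torch

x = torch.tensor([[1.0, 2.0, 4.0],
                  [3.0, 5.0, 7.0]])
mask = torch.tensor([[True, True, True],
                     [True, False, True]])
unbiased = True

# Step-by-step, following the implementation above.
count = torch.where(mask, 1, 0).sum(1, keepdim=True)       # unmasked elements per row: [[3], [2]]
total = torch.where(mask, x, 0.0).sum(1, keepdim=True)     # masked sum per row
mean = total / count                                       # masked mean per row
sq_dev = torch.where(mask, (x - mean) ** 2, 0.0).sum(1)    # masked sum of squared deviations
denom = (count.reshape(sq_dev.shape) - int(unbiased)).clamp(min=0)
manual = sq_dev / denom

print(manual)                                        # tensor([2.3333, 8.0000])
print(torch._masked.var(x, 1, unbiased, mask=mask))  # should match the manual computation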
@@ -1019,6 +1019,35 @@ def sample_inputs_masked_norm(op_info, device, dtype, requires_grad, **kwargs):
     return inputs
 
 
+def sample_inputs_masked_var(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for masked var.
+    """
+    inputs: List[SampleInput] = []
+    for unbiased in [False, True]:
+        for sample_input in sample_inputs_masked_reduction(op_info, device, dtype, requires_grad, **kwargs):
+            if sample_input.args:
+                dim = sample_input.args[0]
+                sample_input_args = sample_input.args[:1] + (unbiased,) + sample_input.args[1:]
+                sample_input_kwargs = sample_input.kwargs.copy()
+            else:
+                dim = sample_input.kwargs.get('dim')
+                sample_input_args = sample_input.args
+                sample_input_kwargs = dict(sample_input.kwargs, unbiased=unbiased)
+            if requires_grad:
+                inmask = torch._masked._input_mask(sample_input.input, *sample_input_args, **sample_input_kwargs)
+                orig_count = torch._masked.sum(inmask.new_ones(sample_input.input.shape, dtype=torch.int64),
+                                               dim, keepdim=True, mask=inmask)
+                if orig_count.min() <= int(unbiased):
+                    # Skip samples that lead to singularities in var
+                    # computation resulting nan values both in var and
+                    # autograd output that test_grad_fn cannot handle
+                    # correctly.
+                    continue
+            inputs.append(SampleInput(sample_input.input.detach().clone().requires_grad_(requires_grad),
+                                      args=sample_input_args, kwargs=sample_input_kwargs))
+    return inputs
+
+
 # NOTE [Reductions]:
 #
 # For testing purposes, we relax the definition of a reduction operator
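The orig_count <= int(unbiased) guard above skips slices whose variance denominator count - ddof would reach zero. A tiny illustration of why such samples would poison the gradient tests with nan (values are illustrative only):

import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
mask = torch.tensor([[True, False],   # only one unmasked element in this row
                     [True, True]])

# With unbiased=True the first row divides by max(count - 1, 0) == 0, giving nan,
# and the nan would propagate into the autograd checks as well.
print(torch._masked.var(x, 1, True, mask=mask))   # tensor([nan, 0.5000])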
@@ -7917,6 +7946,11 @@ def reference_reduction_numpy(f, supports_keepdims=True):
                 identity = identity.cpu()
             kwargs['initial'] = identity.numpy()
 
+        if 'unbiased' in keys:
+            unbiased = kwargs.pop('unbiased')
+            if unbiased is not None:
+                kwargs['ddof'] = int(unbiased)
+
         result = f(x, *args, **kwargs)
 
         # Unsqueeze reduced dimensions if NumPy does not support keepdims
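The lines added here map torch's unbiased flag onto NumPy's ddof ("delta degrees of freedom") so that np.var can serve as the reference implementation. A quick illustration of that equivalence on plain (unmasked) data, with made-up values:

import numpy as np
import torch

a = np.array([1.0, 2.0, 4.0])
t = torch.from_numpy(a)

for unbiased in (False, True):
    ref = np.var(a, ddof=int(unbiased))           # ddof=0 -> uncorrected, ddof=1 -> Bessel-corrected
    res = torch.var(t, unbiased=unbiased).item()
    print(unbiased, ref, res)                     # the two values agree for each flag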
@@ -13751,6 +13785,32 @@ op_db: List[OpInfo] = [
         sample_inputs_func=sample_inputs_masked_norm,
         gradcheck_wrapper=gradcheck_wrapper_masked_operation
     ),
+    ReductionOpInfo(
+        '_masked.var',
+        ref=reference_reduction_numpy(np.var) if np.lib.NumpyVersion(np.__version__) >= '1.20.2' else None,
+        method_variant=None,
+        nan_policy='propagate',
+        supports_out=False,
+        promotes_int_to_float=True,
+        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+        skips=(
+            # FIXME: sum reduces all dimensions when dim=[]
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
+            # RuntimeError: undefined value tensor
+            DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+        ),
+        decorators=[
+            DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
+                         'TestReductions', 'test_reference_masked'),
+            DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
+                         'TestReductions', 'test_ref_small_input'),
+            DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
+                         'TestMasked', 'test_reference_masked'),
+        ],
+        sample_inputs_func=sample_inputs_masked_var,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation
+    ),
     OpInfo(
         '_masked.softmax',
         method_variant=None,