Strided masked var. (#68738)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68738

Test Plan: Imported from OSS

Reviewed By: davidberard98

Differential Revision: D32767155

Pulled By: cpuhrsch

fbshipit-source-id: a5c095103405fbfc28b9f4fd624bdbbc45e7f715
Author: Pearu Peterson
Date: 2021-12-01 19:17:33 -08:00
Committed by: Facebook GitHub Bot
Parent: 291e56eda4
Commit: 370d0afc1b
3 changed files with 126 additions and 3 deletions
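The diff adds a strided masked variance reduction under torch._masked. A minimal usage sketch, assuming the var signature introduced in the implementation hunk below (the example tensors and the comment on the result are illustrative, not taken from the PR):

import torch

x = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])
mask = torch.tensor([[True, True, False],
                     [True, True, True]])

# Variance along dim=1; masked-out elements are excluded from both the
# mean and the squared-deviation sum.
out = torch._masked.var(x, dim=1, unbiased=False, mask=mask)
# Row 0 reduces over [1., 2.] only; row 1 reduces over all three elements.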


@@ -150,6 +150,7 @@ def apply_masked_normalization_along_dim(op, input, *args, **kwargs):
reference_functions = dict(
norm=lambda *args, **kwargs: apply_masked_reduction_along_dim(torch.linalg.vector_norm, *args, **dict(kwargs, dim_position=1)),
var=lambda *args, **kwargs: apply_masked_reduction_along_dim(torch.var, *args, **dict(kwargs, dim_position=0)),
softmax=lambda *args, **kwargs: apply_masked_normalization_along_dim(torch.softmax, *args, **kwargs),
log_softmax=lambda *args, **kwargs: apply_masked_normalization_along_dim(torch.log_softmax, *args, **kwargs),
softmin=lambda *args, **kwargs: apply_masked_normalization_along_dim(torch.nn.functional.softmin, *args, **kwargs),
@@ -167,10 +168,14 @@ class TestMasked(TestCase):
@suppress_warnings
@ops(masked_ops_with_references)
def test_reference_masked(self, device, dtype, op):
ref_op = reference_functions[op.name.rsplit('.', 1)[-1]]
op_name = op.name.rsplit('.', 1)[-1]
ref_op = reference_functions[op_name]
sample_inputs = op.sample_inputs(device, dtype)
for sample_input in sample_inputs:
t_inp, t_args, t_kwargs = sample_input.input, sample_input.args, sample_input.kwargs
if op_name == 'var' and not (t_inp.dtype.is_floating_point or t_inp.dtype.is_complex):
# torch.var does not support integer inputs
continue
actual = op.op(t_inp, *t_args, **t_kwargs)
expected = ref_op(t_inp, *t_args, **t_kwargs)
outmask = torch._masked._output_mask(op.op, t_inp, *t_args, **t_kwargs)
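The test above compares each masked op against a plain-PyTorch reference. A hedged sketch of the property being checked for var: along the reduced dimension, the masked result should agree with torch.var applied to just the unmasked elements of each slice (tensor values are illustrative):

import torch

x = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])
mask = torch.tensor([[True, False, True],
                     [True, True, True]])

actual = torch._masked.var(x, dim=1, unbiased=False, mask=mask)
expected = torch.stack([torch.var(row[m], unbiased=False)
                        for row, m in zip(x, mask)])
assert torch.allclose(actual, expected)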


@@ -1,3 +1,5 @@
# -*- coding: utf-8 -*-
from typing import Optional, Tuple, List, Union, Any
import torch
@@ -141,6 +143,7 @@ Example::
amax=(('dim',), ('keepdim=False', 'dtype=None', 'mask=None')),
mean=(('dim',), ('keepdim=False', 'dtype=None', 'mask=None')),
norm=(('ord', 'dim',), ('keepdim=False', 'dtype=None', 'mask=None')),
var=(('dim', 'unbiased'), ('keepdim=False', 'dtype=None', 'mask=None')),
softmax=(('dim__as_int',), ('dtype=None', 'mask=None')),
log_softmax=(('dim__as_int',), ('dtype=None', 'mask=None')),
softmin=(('dim__as_int',), ('dtype=None', 'mask=None')),
@@ -159,6 +162,9 @@ ord (int, float, optional): the order of vector norm. Default: 2.
ord__required='''\
ord (int, float): the order of vector norm. Default: 2.
See :func:`torch.linalg.vector_norm` for a list of supported norms.''',
unbiased='''\
unbiased (bool): when True, use Bessel's correction; otherwise, compute
the uncorrected sample variance.''',
eps='''\
eps (float, optional): small value to avoid division by zero. Default: {default}.''',
keepdim='''\
@@ -199,7 +205,8 @@ defined as ``x[i]/max(norm(x, p), eps)``.''')
amax='maximum',
amin='minimum',
mean='mean',
norm='norm')
norm='norm',
var='variance')
normalization_names = dict(
softmax='softmax',
@@ -219,6 +226,8 @@ defined as ``x[i]/max(norm(x, p), eps)``.''')
if func.__name__ in {'norm', 'normalize'}:
example_args = (2.0, example_dim)
example_input = example_input.to(dtype=torch.float32)
elif func.__name__ in {'var'}:
example_args = (example_dim, False)
else:
example_args = (example_dim,)
@@ -340,6 +349,8 @@ def _reduction_identity(op_name: str, input: Tensor, *args):
assert torch.is_floating_point(input), input.dtype
return torch.tensor(torch.inf, dtype=dtype, device=device)
return torch.tensor(0, dtype=dtype, device=device)
elif op_name == 'var':
return None
raise NotImplementedError(f'identity of {op_name} on {dtype} input')
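The new branch returns None because variance has no reduction identity. A small sketch of the consequence, consistent with the docstring added further down: a slice whose mask is all False has zero unmasked elements, so the final 0/0 division produces nan instead of an identity value:

import torch

x = torch.ones(2, 3)
mask = torch.tensor([[False, False, False],
                     [True, True, True]])
out = torch._masked.var(x, dim=1, mask=mask)
# out[0] is nan (fully masked-out slice); out[1] is 0. (variance of ones).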
@@ -383,7 +394,7 @@ def _output_mask(op, input: Tensor, *args, **kwargs) -> Tensor:
"""Return output mask of masked operation applied to given arguments.
"""
if callable(op):
is_reduction = op.__name__ in {'sum', 'prod', 'amax', 'amin', 'mean', 'norm'}
is_reduction = op.__name__ in {'sum', 'prod', 'amax', 'amin', 'mean', 'norm', 'var'}
is_normalization = op.__name__ in {'softmax', 'log_softmax', 'softmin', 'normalize'}
if is_reduction:
if op.__name__ == 'norm':
@@ -575,6 +586,53 @@ reduction, is ``{identity_float32}``, except for ``ord=-inf`` it is
raise ValueError(f'masked norm expects strided tensor (got {input.layout} tensor)')
@_apply_docstring_templates
def var(input: Tensor,
dim: DimOrDims = None,
unbiased: Optional[bool] = False,
*,
keepdim: Optional[bool] = False,
dtype: Optional[DType] = None,
mask: Optional[Tensor] = None) -> Tensor:
"""\
{reduction_signature}
{reduction_descr}
The identity value of the sample variance operation is undefined. The
elements of the output tensor with strided layout that correspond to
fully masked-out elements have ``nan`` values.
{reduction_args}
{reduction_example}"""
if dtype is None:
dtype = input.dtype
if not (dtype.is_floating_point or dtype.is_complex):
dtype = torch.float32
compute_dtype = dtype
if not (compute_dtype.is_floating_point or compute_dtype.is_complex):
compute_dtype = torch.float32
if input.layout == torch.strided:
inmask = _input_mask(input, mask=mask)
count = sum(inmask.new_ones(input.shape, dtype=torch.int64), dim, keepdim=True, mask=inmask)
sample_total = sum(input, dim, keepdim=True, dtype=dtype, mask=inmask)
# TODO: replace torch.subtract/divide/square/maximum with
# masked subtract/divide/square/maximum when these become
# available.
sample_mean = torch.divide(sample_total, count)
x = torch.subtract(input, sample_mean)
total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype, mask=inmask)
if not keepdim:
count = count.reshape(total.shape)
if unbiased:
count = torch.subtract(count, 1)
count = torch.maximum(count, count.new_zeros([]))
return torch.divide(total, count).to(dtype=dtype)
else:
raise ValueError(f'masked var expects strided tensor (got {input.layout} tensor)')
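The implementation above is a two-pass computation: count the unmasked elements, form the masked mean, then divide the masked sum of squared deviations by max(count - int(unbiased), 0). A small self-check, assuming that with an all-True mask the result should match plain torch.var:

import torch

x = torch.randn(4, 5)
full = torch.ones_like(x, dtype=torch.bool)
for unbiased in (False, True):
    masked = torch._masked.var(x, dim=1, unbiased=unbiased, mask=full)
    plain = torch.var(x, dim=1, unbiased=unbiased)
    assert torch.allclose(masked, plain)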
@_apply_docstring_templates
def softmax(input: Tensor,
dim: int,


@@ -1019,6 +1019,35 @@ def sample_inputs_masked_norm(op_info, device, dtype, requires_grad, **kwargs):
return inputs
def sample_inputs_masked_var(op_info, device, dtype, requires_grad, **kwargs):
"""Sample inputs for masked var.
"""
inputs: List[SampleInput] = []
for unbiased in [False, True]:
for sample_input in sample_inputs_masked_reduction(op_info, device, dtype, requires_grad, **kwargs):
if sample_input.args:
dim = sample_input.args[0]
sample_input_args = sample_input.args[:1] + (unbiased,) + sample_input.args[1:]
sample_input_kwargs = sample_input.kwargs.copy()
else:
dim = sample_input.kwargs.get('dim')
sample_input_args = sample_input.args
sample_input_kwargs = dict(sample_input.kwargs, unbiased=unbiased)
if requires_grad:
inmask = torch._masked._input_mask(sample_input.input, *sample_input_args, **sample_input_kwargs)
orig_count = torch._masked.sum(inmask.new_ones(sample_input.input.shape, dtype=torch.int64),
dim, keepdim=True, mask=inmask)
if orig_count.min() <= int(unbiased):
# Skip samples that lead to singularities in the var
# computation, resulting in nan values in both var and the
# autograd output, which test_grad_fn cannot handle
# correctly.
continue
inputs.append(SampleInput(sample_input.input.detach().clone().requires_grad_(requires_grad),
args=sample_input_args, kwargs=sample_input_kwargs))
return inputs
# NOTE [Reductions]:
#
# For testing purposes, we relax the definition of a reduction operator
@@ -7917,6 +7946,11 @@ def reference_reduction_numpy(f, supports_keepdims=True):
identity = identity.cpu()
kwargs['initial'] = identity.numpy()
if 'unbiased' in keys:
unbiased = kwargs.pop('unbiased')
if unbiased is not None:
kwargs['ddof'] = int(unbiased)
result = f(x, *args, **kwargs)
# Unsqueeze reduced dimensions if NumPy does not support keepdims
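The block above maps PyTorch's boolean unbiased flag onto NumPy's ddof ("delta degrees of freedom") so the NumPy reference can mirror torch.var. A small illustration of the mapping, unbiased=True <-> ddof=1 and unbiased=False <-> ddof=0:

import numpy as np
import torch

a = np.array([1.0, 2.0, 4.0, 7.0])
t = torch.from_numpy(a)
assert np.isclose(np.var(a, ddof=1), torch.var(t, unbiased=True).item())
assert np.isclose(np.var(a, ddof=0), torch.var(t, unbiased=False).item())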
@@ -13751,6 +13785,32 @@ op_db: List[OpInfo] = [
sample_inputs_func=sample_inputs_masked_norm,
gradcheck_wrapper=gradcheck_wrapper_masked_operation
),
ReductionOpInfo(
'_masked.var',
ref=reference_reduction_numpy(np.var) if np.lib.NumpyVersion(np.__version__) >= '1.20.2' else None,
method_variant=None,
nan_policy='propagate',
supports_out=False,
promotes_int_to_float=True,
dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
skips=(
# FIXME: sum reduces all dimensions when dim=[]
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
# RuntimeError: undefined value tensor
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
),
decorators=[
DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
'TestReductions', 'test_reference_masked'),
DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
'TestReductions', 'test_ref_small_input'),
DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
'TestMasked', 'test_reference_masked'),
],
sample_inputs_func=sample_inputs_masked_var,
gradcheck_wrapper=gradcheck_wrapper_masked_operation
),
OpInfo(
'_masked.softmax',
method_variant=None,