OpInfo: Sample input cleanup (4/n) (#86324)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86324
Approved by: https://github.com/mruberry

Author: Peter Bell
Date: 2022-10-19 17:00:52 +01:00
Committed by: PyTorch MergeBot
Commit: 6eeeb88172 (parent: c141f28b64)

5 changed files with 64 additions and 56 deletions
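
The recurring change across these five files: OpInfo sample-input functions that eagerly built and returned lists of SampleInput objects now construct them lazily, by yielding or by returning a generator expression. A minimal sketch of the before/after shape, using a hypothetical stand-in for the real SampleInput class from torch.testing._internal:

from dataclasses import dataclass, field
from typing import Any

# Hypothetical stand-in for torch.testing._internal's SampleInput.
@dataclass
class SampleInput:
    input: Any
    args: tuple = ()
    kwargs: dict = field(default_factory=dict)

# Before: eagerly build a list, then return it.
def sample_inputs_list(make_arg):
    samples = []
    for shape in [(2, 2), (3, 3)]:
        samples.append(SampleInput(make_arg(shape)))
    return samples

# After: yield each sample as it is constructed; tensors are only
# materialized when the test actually iterates.
def sample_inputs_generator(make_arg):
    for shape in [(2, 2), (3, 3)]:
        yield SampleInput(make_arg(shape))

The payoff is that tests which stop early, or which only need a few samples, no longer pay for constructing every input tensor up front.
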

View File

@@ -648,8 +648,6 @@ class TestOperators(TestCase):
     xfail("take"),  # vmap: inplace into a regular tensor
     xfail("to"),  # rank 4 tensor for channels_last
     xfail("view_as_complex"),  # RuntimeError: Tensor must have a last dimension with stride 1
-    xfail("masked.softmax", device_type='cuda'),  # Mismatch in values!
-    xfail("masked.softmin", device_type='cuda'),  # Mismatch in values!
     # got a batched tensor as input while the running_mean or running_var,
     # which will be updated in place, were not batched.
     xfail("nn.functional.batch_norm", 'without_cudnn'),
@@ -1626,6 +1624,8 @@ class TestOperators(TestCase):
         {torch.float32: tol(atol=5e-04, rtol=9e-03)}, device_type='cuda'),
     tol1('linalg.householder_product',
         {torch.float32: tol(atol=1e-04, rtol=1e-04)}, device_type='cpu'),
+    tol1('linalg.multi_dot',
+        {torch.float32: tol(atol=2e-04, rtol=1e-04)}, device_type='cuda'),
     tol2('linalg.pinv', 'hermitian',
         {torch.float32: tol(atol=5e-06, rtol=5e-06)}),
 ))

View File

@@ -200,6 +200,7 @@ def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
     # Exceeds tolerances on CUDA, likely due to fma
     (torch.float32, torch.ops.aten.mv.default) : (1e-5, 3e-5),
     (torch.float64, torch.ops.aten.upsample_bicubic2d.vec) : (1e-5, 1e-6),
+    (torch.complex64, torch.ops.aten.mv.default): (5e-5, 5e-5),
 }
 if (test_dtype, op) in tol_table:
     rtol, atol = tol_table[(decomp.dtype, op)]
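
op_assert_equal consults a table keyed on (dtype, overload) before falling back to default tolerances; the new complex64 entry loosens mv, which the comment attributes to fma on CUDA. (Note the guard tests (test_dtype, op) while the read indexes with (decomp.dtype, op); presumably the two agree whenever the guard fires.) A small self-contained sketch of the same lookup pattern, with illustrative default values:

import torch

# Per-(dtype, overload) tolerance overrides; names and defaults here
# are illustrative, not PyTorch's actual values.
tol_table = {
    (torch.float32, torch.ops.aten.mv.default): (1e-5, 3e-5),
    (torch.complex64, torch.ops.aten.mv.default): (5e-5, 5e-5),
}

def get_tolerances(dtype, op, default_rtol=1.3e-6, default_atol=1e-5):
    # Fall back to generic defaults when no override is registered.
    return tol_table.get((dtype, op), (default_rtol, default_atol))

rtol, atol = get_tolerances(torch.complex64, torch.ops.aten.mv.default)
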
@@ -294,8 +295,11 @@ CROSS_REF_EXCLUDE_SET = {
 }
 CROSS_REF_BACKWARD_EXCLUDE_SET = {
     # Backward formula is not as precise as the custom CUDA kernel
     ("cuda", torch.float16, "nn.functional.embedding"),
     ("cuda", torch.bfloat16, "nn.functional.embedding"),
+    # Decomposed backward formula is not as precise
+    ("cpu", torch.bfloat16, "nn.functional.hardswish"),
+    ("cuda", torch.float16, "nn.functional.cross_entropy"),
 }

 all_decomposed = set()

View File

@@ -3513,9 +3513,9 @@ def sample_inputs_local_response_norm(opinfo, device, dtype, requires_grad, **kw
 def sample_inputs_hardswish(self, device, dtype, requires_grad, **kwargs):
     N = 5
     # make sure we are testing -3 -> 3 range. default is -10 -> 10 so maybe unnecessary ?
-    tensors = [SampleInput(make_tensor((N * 2, N * 2), device=device, dtype=dtype,
-               requires_grad=requires_grad, low=-5, high=5)) for _ in range(1, N)]
-    return tensors
+    make_arg = partial(make_tensor, device=device, dtype=dtype,
+                       requires_grad=requires_grad, low=-5, high=5)
+    return (SampleInput(make_arg((N * 2, N * 2))) for _ in range(1, N))

 def sample_inputs_linear(self, device, dtype, requires_grad, **kwargs):
     features_options = [[3, 4], [8, 8]]
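
Two idioms show up in this hardswish rewrite and recur throughout the PR: functools.partial pins the repeated make_tensor arguments, and the function returns a generator expression instead of a list. A runnable sketch using the public torch.testing.make_tensor helper (device, shapes, and bounds are illustrative):

from functools import partial
import torch
from torch.testing import make_tensor

# Factor the shared tensor-construction arguments once.
make_arg = partial(make_tensor, device="cpu", dtype=torch.float32,
                   requires_grad=True, low=-5, high=5)

# Lazy, like the rewritten sample_inputs_hardswish.
inputs = (make_arg((10, 10)) for _ in range(4))
for t in inputs:
    assert t.min() >= -5 and t.max() <= 5
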
@@ -4692,21 +4692,19 @@ def sample_inputs_std_var(op_info, device, dtype, requires_grad, **kwargs):
     tensor_1d = partial(make_tensor, (S,), device=device, dtype=dtype,
                         requires_grad=requires_grad)
-    return [
-        SampleInput(tensor_nd()),
-        SampleInput(tensor_nd(), dim=1),
-        SampleInput(tensor_nd(), dim=1, unbiased=True, keepdim=True),
-        SampleInput(tensor_1d(), dim=0, unbiased=True, keepdim=True),
-        SampleInput(tensor_1d(), dim=0, unbiased=False, keepdim=False),
+    yield SampleInput(tensor_nd())
+    yield SampleInput(tensor_nd(), dim=1)
+    yield SampleInput(tensor_nd(), dim=1, unbiased=True, keepdim=True)
+    yield SampleInput(tensor_1d(), dim=0, unbiased=True, keepdim=True)
+    yield SampleInput(tensor_1d(), dim=0, unbiased=False, keepdim=False)

-        SampleInput(tensor_nd(), dim=(1,), correction=S // 2),
-        SampleInput(tensor_nd(), dim=None, correction=0, keepdim=True),
+    yield SampleInput(tensor_nd(), dim=(1,), correction=S // 2)
+    yield SampleInput(tensor_nd(), dim=None, correction=0, keepdim=True)
+    yield SampleInput(tensor_nd(), dim=None, correction=None)

-        # Test var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
-        SampleInput(tensor_nd(), True),
-        SampleInput(tensor_nd(), False),
-        SampleInput(tensor_nd(), dim=None, correction=None),
-    ]
+    # Test var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
+    yield SampleInput(tensor_nd(), True)
+    yield SampleInput(tensor_nd(), False)

 def _generate_correlation_inputs(device, dtype, requires_grad, **kwargs):
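
The std/var sampler covers both spellings of the bias argument. correction generalizes unbiased: the variance divisor is N - correction, so correction=0 matches unbiased=False and correction=1 matches unbiased=True. A quick check, assuming a PyTorch recent enough to expose the correction overload:

import torch

x = torch.randn(4, 5)

# correction generalizes unbiased: the divisor is N - correction.
assert torch.allclose(torch.var(x, dim=1, correction=0),
                      torch.var(x, dim=1, unbiased=False))
assert torch.allclose(torch.var(x, dim=1, correction=1),
                      torch.var(x, dim=1, unbiased=True))

# var_mean returns both statistics in one call, as exercised by
# SampleInput(tensor_nd(), True) above (positional unbiased).
var, mean = torch.var_mean(x, 1)
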
@@ -5253,7 +5251,6 @@ def sample_inputs_cross_entropy(op_info, device, dtype, requires_grad, **kwargs)
         (shape, dict(ignore_index=1)),
     ]

-    sample_inputs = []
     for (input_shape, kwargs), probabilities_target in itertools.product(input_shape_and_kwargs, (False, True)):
         input = make_tensor(input_shape, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -5283,9 +5280,7 @@ def sample_inputs_cross_entropy(op_info, device, dtype, requires_grad, **kwargs)
             # make sure at least one item in target is not ignored
             target[0] = random.sample(set(range(num_classes)) - {kwargs["ignore_index"]}, 1)[0]

-        sample_inputs.append(SampleInput(input, args=(target,), kwargs=kwargs))
-    return sample_inputs
+        yield SampleInput(input, target, **kwargs)

 def sample_inputs_logit(op_info, device, dtype, requires_grad, **kwargs):
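
This sampler takes the product of shape/kwargs combinations with both target flavors (hard class indices and soft probabilities), and when ignore_index is set it forces target[0] to a non-ignored class so the loss is never vacuous. The two call shapes it exercises, in a standalone sketch:

import torch
import torch.nn.functional as F

batch, num_classes = 4, 3
logits = torch.randn(batch, num_classes, requires_grad=True)

# Hard class-index targets; entries equal to ignore_index=1 are
# dropped from the loss.
target_idx = torch.tensor([0, 1, 2, 1])
loss = F.cross_entropy(logits, target_idx, ignore_index=1)

# Soft probability targets (supported since PyTorch 1.10): same shape
# as the input, rows summing to 1.
target_prob = torch.softmax(torch.randn(batch, num_classes), dim=1)
loss_soft = F.cross_entropy(logits, target_prob)
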
@@ -5391,6 +5386,8 @@ def sample_inputs_matrix_exp(op_info, device, dtype, requires_grad, **kwargs):
     yield SampleInput(make_arg((S, S, S)))

 def sample_inputs_matmul(op_info, device, dtype, requires_grad, is_rmatmul=False, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, low=None,
+                       high=None, requires_grad=requires_grad)
     test_cases = (((L,), (L,)),
                   ((S, M), (M,)),
                   ((M,), (M, S)),
@@ -5405,15 +5402,13 @@ def sample_inputs_matmul(op_info, device, dtype, requires_grad, is_rmatmul=False
                   ((S, S, M, M), (S, S, M, S)),
                   ((S, S, M, M), (M,)),
                   ((M,), (S, S, M, S)))
-    sample_inputs = []
     for lhs_shape, rhs_shape in test_cases:
-        lhs = make_tensor(lhs_shape, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad)
-        rhs = make_tensor(rhs_shape, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad)
+        lhs = make_arg(lhs_shape)
+        rhs = make_arg(rhs_shape)
         if not is_rmatmul:
-            sample_inputs.append(SampleInput(lhs, args=(rhs,)))
+            yield SampleInput(lhs, rhs)
         else:
-            sample_inputs.append(SampleInput(rhs, args=(lhs,)))
-    return tuple(sample_inputs)
+            yield SampleInput(rhs, lhs)

 def sample_inputs_meshgrid(op_info: OpInfo, device: torch.device, dtype: torch.dtype,
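
The matmul shape table walks torch.matmul's broadcasting rules: vector-vector, matrix-vector, vector-matrix, and batched cases; is_rmatmul merely swaps operand order. The expected output shapes for a few of those pairs, using the test-suite convention L=20, M=10, S=5:

import torch

L, M, S = 20, 10, 5
assert torch.matmul(torch.randn(L), torch.randn(L)).shape == ()       # dot product
assert torch.matmul(torch.randn(S, M), torch.randn(M)).shape == (S,)  # mat-vec
assert torch.matmul(torch.randn(M), torch.randn(M, S)).shape == (S,)  # vec-mat
# Batch dims broadcast; the trailing vector is treated as (M, 1).
assert torch.matmul(torch.randn(S, S, M, M), torch.randn(M)).shape == (S, S, M)
# rmatmul swaps the operands, matching `yield SampleInput(rhs, lhs)`.
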
@@ -9961,7 +9956,11 @@ op_db: List[OpInfo] = [
            supports_out=False,
            supports_forward_ad=True,
            check_batched_forward_grad=False,
-           supports_fwgrad_bwgrad=True),
+           supports_fwgrad_bwgrad=True,
+           decorators=(
+               DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}),
+                            "TestDecomp", "test_comprehensive", device_type="cuda"),
+           )),
     OpInfo('std_mean',
            dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
            sample_inputs_func=sample_inputs_std_var,
@@ -9969,7 +9968,11 @@ op_db: List[OpInfo] = [
            supports_out=False,
            supports_forward_ad=True,
            check_batched_forward_grad=False,
-           supports_fwgrad_bwgrad=True),
+           supports_fwgrad_bwgrad=True,
+           decorators=(
+               DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}),
+                            "TestDecomp", "test_comprehensive", device_type="cuda"),
+           )),
     OpInfo('meshgrid',
            variant_test_name='variadic_tensors',
            ref=np.meshgrid,
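
std and std_mean pick up identical decorators: a toleranceOverride scoped by DecorateInfo to one test class, one test name, and one device type, so only TestDecomp.test_comprehensive on CUDA sees the looser float64 bounds. A sketch of the pattern; these helpers live under torch.testing._internal, which is not a public API, so the import paths below reflect this era of the codebase and may since have moved:

import torch
from torch.testing._internal.common_device_type import toleranceOverride, tol
from torch.testing._internal.common_methods_invocations import DecorateInfo

# Loosen float64 tolerances only for TestDecomp.test_comprehensive on
# CUDA, leaving every other test and device at the defaults.
decorator = DecorateInfo(
    toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}),
    "TestDecomp", "test_comprehensive", device_type="cuda",
)
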
@@ -10738,6 +10741,10 @@ op_db: List[OpInfo] = [
                toleranceOverride({torch.chalf: tol(atol=1e-3, rtol=1e-3)}),
                'TestCudaFuserOpInfo', 'test_nvfuser_correctness',
            ),
+           DecorateInfo(
+               toleranceOverride({torch.float16: tol(atol=2e-3, rtol=1e-3)}),
+               'TestInductorOpInfo', 'test_comprehensive', device_type='cuda',
+           ),
        ),
        skips=(
            # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
@@ -15440,6 +15447,10 @@ op_db: List[OpInfo] = [
                "test_out",
                device_type="meta",
            ),
+           DecorateInfo(
+               toleranceOverride({torch.float16: tol(atol=2e-3, rtol=1e-3)}),
+               'TestInductorOpInfo', 'test_comprehensive', device_type='cuda',
+           ),
        ),
    ),
    OpInfo('t',

View File

@@ -50,9 +50,9 @@ def sample_inputs_softmax_variant(
     if torch.device(device).type != "xla":
         cases.append(((), (0,)))

-    return [
+    return (
         SampleInput(make_arg(shape), args=dim, kwargs=kwargs) for shape, dim in cases
-    ]
+    )

 def _generate_masked_op_mask(input_shape, device, **kwargs):
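
The one-character change from return [ ... ] to return ( ... ) swaps a list comprehension for a generator expression. The trade-off is that the result is single-shot; that is safe here because the OpInfo machinery requests a fresh iterator of samples for each test. A tiny illustration:

# Returning a generator expression: items are built on demand, but the
# result can only be iterated once.
def squares(n):
    return (i * i for i in range(n))

gen = squares(3)
assert list(gen) == [0, 1, 4]
assert list(gen) == []  # exhausted: a second pass yields nothing
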
@@ -281,24 +281,18 @@ def sample_inputs_masked_softmax(
     same shape as input or a shape that is broadcastable to input
     shape.
     """
-    inputs: List[SampleInput] = []
     for sample_input in sample_inputs_softmax_variant(
         op_info, device, dtype, requires_grad, with_dtype=with_dtype, **kwargs
     ):
         for mask in _generate_masked_op_mask(
             sample_input.input.shape, device, **kwargs
         ):
-            sample_input_args, sample_input_kwargs = sample_input.args, dict(
-                mask=mask, **sample_input.kwargs
+            yield SampleInput(
+                sample_input.input.clone().requires_grad_(requires_grad),
+                *sample_input.args,
+                mask=mask,
+                **sample_input.kwargs,
             )
-            inputs.append(
-                SampleInput(
-                    sample_input.input.clone().requires_grad_(requires_grad),
-                    args=sample_input_args,
-                    kwargs=sample_input_kwargs,
-                )
-            )
-    return inputs

 def sample_inputs_masked_cumops(op_info, device, dtype, requires_grad, **kwargs):
@@ -325,16 +319,12 @@ def sample_inputs_masked_cumops(op_info, device, dtype, requires_grad, **kwargs)
                 continue
             dim = sample_input_kwargs.pop("dim")
             sample_input_args = (dim,)
-            inputs.append(
-                SampleInput(
-                    sample_input.input.clone().requires_grad_(requires_grad),
-                    args=sample_input_args,
-                    kwargs=sample_input_kwargs,
-                )
+            yield SampleInput(
+                sample_input.input.clone().requires_grad_(requires_grad),
+                *sample_input_args,
+                **sample_input_kwargs,
             )
-    return inputs

 def sample_inputs_masked_logaddexp(op_info, device, dtype, requires_grad, **kwargs):
     """Sample inputs for masked logaddexp."""
@@ -573,6 +563,12 @@ op_db: List[OpInfo] = [
                "test_backward",
                device_type="cuda",
            ),
+           DecorateInfo(
+               toleranceOverride({torch.float16: tol(atol=2e-3, rtol=2e-3)}),
+               "TestInductorOpInfo",
+               "test_comprehensive",
+               device_type="cuda",
+           ),
        ),
        # Can reuse the same inputs; dim is required in both
        sample_inputs_func=sample_inputs_masked_cumops,

View File

@@ -318,7 +318,6 @@ def sample_inputs_linalg_multi_dot(op_info, device, dtype, requires_grad, **kwar
         [2, 4, 3, 5, 3, 2],
     ]

-    result = []
     for sizes in test_cases:
         tensors = []
         for size in zip(sizes[:-1], sizes[1:]):
@@ -326,9 +325,7 @@ def sample_inputs_linalg_multi_dot(op_info, device, dtype, requires_grad, **kwar
                 size, dtype=dtype, device=device, requires_grad=requires_grad
             )
             tensors.append(t)
-        result.append(SampleInput(tensors))
-    return result
+        yield SampleInput(tensors)

 def sample_inputs_linalg_matrix_norm(op_info, device, dtype, requires_grad, **kwargs):
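
For reference, each multi_dot size list encodes a matrix chain: consecutive sizes become (rows, cols) pairs via zip(sizes[:-1], sizes[1:]), so inner dimensions line up by construction. A standalone check of the [2, 4, 3, 5, 3, 2] case:

import torch

# (2x4) @ (4x3) @ (3x5) @ (5x3) @ (3x2)
sizes = [2, 4, 3, 5, 3, 2]
tensors = [torch.randn(a, b) for a, b in zip(sizes[:-1], sizes[1:])]
out = torch.linalg.multi_dot(tensors)
assert out.shape == (2, 2)  # first row dim x last column dim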