Files
pytorch/test/functorch/functorch_additional_op_db.py
Natalia Gimelshein 34e3930401 fix numpy compatibility for 2d small list indices (#154806)
Will fix #119548 and linked issues once we switch from warning to the new behavior,
but for now, given how much this syntax was used in our test suite, we suspect a silent change will be disruptive.
We will change the behavior after 2.8 branch is cut.
Numpy behavior was changed at least in numpy 1.24 (more than 2 years ago)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154806
Approved by: https://github.com/cyyever, https://github.com/Skylion007, https://github.com/albanD
2025-06-04 01:58:52 +00:00

816 lines
28 KiB
Python

import itertools
import unittest
from functools import partial
import torch
from torch.testing._internal.common_dtype import (
all_types_and_complex_and,
floating_types,
floating_types_and,
)
from torch.testing._internal.common_methods_invocations import (
DecorateInfo,
OpInfo,
SampleInput,
)
from torch.testing._internal.common_utils import make_tensor
# List of OpInfos that aren't in PyTorch Core yet.
# They are here because we wanted a fast way of writing OpInfos, and they may
# not be 100% correct (w.r.t. dtypes and other options).
# TODO: Figure out how to upstream these, delete them when they're upstreamed
# Module-level registry; the statements further down in this file append to or
# extend this list with OpInfo entries.
additional_op_db = []
# https://github.com/pytorch/pytorch/pull/61068
def sample_inputs_conv2d(
    has_bias, self, device, dtype, requires_grad, extra_args=(), groups=1
):
    """Build the single conv2d sample shared by the conv2d OpInfo variants.

    Args:
        has_bias: When true a bias tensor is generated, otherwise ``None`` is
            passed in the bias position.
        self: Unused here (presumably the OpInfo instance supplied by the
            test framework once ``has_bias`` is bound via ``partial``).
        device: Device for every generated tensor.
        dtype: Dtype for every generated tensor.
        requires_grad: ``requires_grad`` flag for every generated tensor.
        extra_args: Tuple appended after ``(weight, bias)`` in the sample's
            positional arguments.
        groups: Multiplier applied to the input and output channel counts.

    Returns:
        A one-element list holding the ``SampleInput``.
    """
    in_ch, out_ch = 6, 4

    # Input, weight and bias all share these generation settings.
    make_arg = partial(
        make_tensor,
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        low=-1,
        high=1,
    )

    # Tensors are created in the order input, weight, bias so that random
    # number consumption is unchanged.
    inp = make_arg((2, in_ch * groups, 7, 5))
    weight = make_arg((out_ch * groups, in_ch, 3, 2))
    bias = make_arg((out_ch * groups,)) if has_bias else None

    conv_args = (weight, bias) + extra_args
    return [SampleInput(inp, args=conv_args)]
def _conv2d_variant(variant_test_name, has_bias, extra_args=(), groups=1):
    """Build the OpInfo for one ``nn.functional.conv2d`` test variant.

    Every variant shares the same dtypes and autograd flags; only the variant
    name and the arguments bound into ``sample_inputs_conv2d`` differ.

    Args:
        variant_test_name: Name that distinguishes this variant.
        has_bias: Whether the generated sample carries a bias tensor.
        extra_args: Positional arguments appended after ``(weight, bias)``.
        groups: Forwarded to ``sample_inputs_conv2d`` to scale channel counts.

    Returns:
        The configured ``OpInfo``.
    """
    return OpInfo(
        "nn.functional.conv2d",
        aten_name="conv2d",
        variant_test_name=variant_test_name,
        supports_autograd=True,
        supports_forward_ad=True,
        sample_inputs_func=partial(
            sample_inputs_conv2d, has_bias, extra_args=extra_args, groups=groups
        ),
        dtypes=floating_types(),
        dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
        supports_out=False,
    )


# (variant name, has_bias, extra positional conv2d args, groups).
# The extra args follow conv2d's positional order after bias:
# stride, padding, dilation, groups.
additional_op_db.extend(
    _conv2d_variant(name, has_bias, extra_args, groups)
    for name, has_bias, extra_args, groups in (
        ("no_bias", False, (), 1),
        ("with_bias", True, (), 1),
        # NOTE: the trailing comma matters. ``((2, 2))`` is just ``(2, 2)``
        # and would be unpacked as stride=2, padding=2 rather than
        # stride=(2, 2).
        ("stride_with_bias", True, ((2, 2),), 1),
        ("stride_no_bias", False, ((2, 2),), 1),
        ("stride_padding_with_bias", True, ((2, 2), (1, 1)), 1),
        ("stride_padding_no_bias", False, ((2, 2), (1, 1)), 1),
        ("strided_padding_dilation_with_bias", True, ((2, 2), (1, 1), (2, 2)), 1),
        # This variant previously bound has_bias=True, duplicating the
        # with-bias variant above instead of exercising the bias-free path.
        ("strided_padding_dilation_no_bias", False, ((2, 2), (1, 1), (2, 2)), 1),
        ("stride_groups_with_bias", True, ((2, 3), 0, 1, 2), 2),
        ("stride_depthwise_with_bias", True, ((2, 3), 0, 1, 6), 6),
    )
)
# TODO: PyTorch core has a check for whether requires_grad=True or not.
# We actually want to test more things for backward here, which is why we have our own.
def sample_inputs_embedding(op_info, device, dtype, requires_grad, **kwargs):
    """Return embedding samples with the weight first and the indices in ``args``.

    Args:
        op_info: Unused.
        device: Device for every generated tensor.
        dtype: Dtype of the weight tensors (index tensors are ``torch.long``).
        requires_grad: ``requires_grad`` flag of the weight tensors.
        **kwargs: Unused.

    Returns:
        A list of ``SampleInput`` objects.
    """
    M = 20
    S = 5

    def make_weight(shape):
        return make_tensor(
            shape, device=device, dtype=dtype, requires_grad=requires_grad
        )

    def make_indices(shape, high):
        return make_tensor(shape, device=device, dtype=torch.long, low=0, high=high)

    samples = []

    # 0-D, 1-D and 2-D index tensors into an (M, S) weight.
    # The index tensor is always created before the weight.
    for idx_shape in ((), (S,), (S, S)):
        idx = make_indices(idx_shape, M)
        samples.append(SampleInput(make_weight((M, S)), args=(idx,)))

    # padding_idx cases: the diagonal of a 2x2 index tensor is forced to
    # ``forced_index`` before the sample is built with the given padding_idx.
    for forced_index, padding_idx in ((2, 2), (4, -1)):
        idx = make_indices((2, 2), S)
        idx[0, 0] = forced_index
        idx[1, 1] = forced_index
        samples.append(
            SampleInput(
                make_weight((S, S)),
                args=(idx,),
                kwargs={"padding_idx": padding_idx},
            )
        )

    # Scale the gradient based on the inverse frequency of a particular index;
    # index 1 is written twice so that it is repeated.
    idx = make_indices((2, 2), S)
    idx[0, 0] = 1
    idx[0, 1] = 1
    weights = make_weight((S, S))
    samples.append(
        SampleInput(
            weights,
            args=(idx,),
            kwargs={"scale_grad_by_freq": True},
        )
    )

    return samples
additional_op_db.append(
    OpInfo(
        "nn.functional.embedding",
        variant_test_name="functorch",
        sample_inputs_func=sample_inputs_embedding,
        # The samples hold the weight as ``input`` and the indices in ``args``
        # because currently only the `input` field of SampleInput is tested in
        # gradient tests. The lambda reshuffles the positional arguments back
        # into the order torch.nn.functional.embedding is called with here.
        op=lambda weight, idx, **kwargs: torch.nn.functional.embedding(
            idx, weight, **kwargs
        ),
        dtypes=floating_types_and(torch.bfloat16, torch.float16),
        supports_out=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
    )
)
def sample_inputs_mse_loss(op_info, device, dtype, requires_grad, **kwargs):
    """Yield mse_loss samples for 2-D, 3-D and 4-D inputs and each reduction.

    Args:
        op_info: Unused.
        device: Device for every generated tensor.
        dtype: Dtype for every generated tensor.
        requires_grad: ``requires_grad`` flag of the first tensor, and the
            default for the second one.
        **kwargs: ``rhs_requires_grad``, when present, overrides
            ``requires_grad`` for the second tensor.

    Yields:
        One ``SampleInput`` per (shape, reduction) pair, with shapes in the
        outer loop and reductions in the inner loop.
    """
    second_requires_grad = kwargs.get("rhs_requires_grad", requires_grad)
    S = 5

    for ndim in (2, 3, 4):
        shape = (S,) * ndim
        for reduction in ("none", "mean", "sum"):
            # The first tensor is created before the second one.
            first = make_tensor(
                shape, device=device, dtype=dtype, requires_grad=requires_grad
            )
            second = make_tensor(
                shape, device=device, dtype=dtype, requires_grad=second_requires_grad
            )
            yield SampleInput(
                first,
                args=(second,),
                kwargs={"reduction": reduction},
            )
additional_op_db.append(
    OpInfo(
        "nn.functional.mse_loss",
        variant_test_name="functorch",
        sample_inputs_func=sample_inputs_mse_loss,
        # Forward dtypes: the default set, then the CUDA-specific set.
        dtypes=floating_types_and(torch.float16),
        dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
        # Backward dtypes: the default set, then the CUDA-specific set.
        backward_dtypes=floating_types(),
        backward_dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
        supports_out=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
    )
)
# TODO: upstream sample inputs to pytorch/pytorch.
# We are more comprehensive.
def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs):
    """Return ``__getitem__`` samples mixing slices, lists, tensors, None and Ellipsis.

    Args:
        op_info: Unused.
        device: Device of the indexed tensors (the index tensors are built
            with ``torch.LongTensor`` and no device argument).
        dtype: Dtype of the indexed tensors.
        requires_grad: ``requires_grad`` flag of the indexed tensors.
        **kwargs: Unused.

    Returns:
        A tuple of ``SampleInput`` objects, one per case in the table below.
    """
    S = 5
    full = slice(None)
    # "adv" is short for "advanced index"
    adv = torch.LongTensor([[0, 1], [2, 3]])

    # (number of dimensions of self, positional args of the sample)
    cases = [
        (3, ([1, 2],)),
        (3, (slice(0, 3),)),
        (3, ((slice(0, 3), 1),)),
        (3, (([0, 2, 3], [1, 3, 3], [0, 0, 2]),)),
        (3, (([0, 0, 3], [1, 1, 3], [0, 0, 2]),)),
        (3, ((full, full, [0, 3]),)),
        (3, ((full, [0, 3], full),)),
        (3, (([0, 3], full, full),)),
        (3, (([0, 3], [1, 2], full),)),
        # A nested list used directly as the index.
        (3, ([[0, 3]],)),
        (3, (([0, 3], full),)),
        (3, (([0, 3], Ellipsis),)),
        (3, (([0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])),)),
        (4, ((full, adv, adv, full),)),
        (4, ((full, adv, full, adv),)),
        (4, ((adv, full, full, adv),)),
        (4, ((full, full, adv, adv),)),
        (4, ((Ellipsis, adv, adv),)),
        (5, ((full, full, adv, full, adv),)),
        (5, ((full, full, adv, adv, full),)),
        (5, ((full, full, adv, None, adv, full),)),
        (6, ((full, full, full, adv, adv),)),
        (6, ((full, full, adv, adv, adv),)),
        (6, ((full, full, None, adv, adv, adv),)),
    ]

    samples = []
    for ndim, index_args in cases:
        # Dimension sizes are S, S + 1, ..., S + ndim - 1.
        shape = tuple(range(S, S + ndim))
        indexed = make_tensor(
            shape,
            device=device,
            dtype=dtype,
            low=None,
            high=None,
            requires_grad=requires_grad,
        )
        samples.append(SampleInput(indexed, args=index_args))
    return tuple(samples)
# TODO: split PyTorch's __getitem__. The problem is we don't support indexing
# with masks with vmap.
additional_op_db.append(
    OpInfo(
        "__getitem__",
        variant_test_name="functorch",
        op=torch.Tensor.__getitem__,
        sample_inputs_func=sample_inputs_getitem,
        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
        # Capabilities that are switched off for this entry.
        supports_out=False,
        supports_inplace_autograd=False,
        supports_scripting=False,
        assert_jit_shape_analysis=False,  # TODO: support index.Tensor()
        # Forward-mode AD is enabled.
        supports_forward_ad=True,
    )
)
# Turns out at::index_put is different from torch.index_put...
# TODO: figure out how to upstream this
def sample_inputs_aten_index_put(op_info, device, dtype, requires_grad, **kwargs):
    """Return index_put samples whose indices tuple may contain ``None`` entries.

    For every (shape, indices) case two samples are produced: one whose
    values tensor has the full shape of the indexed selection, and one whose
    values tensor drops the first three dimensions of that shape.

    Args:
        op_info: Unused.
        device: Device of the self and values tensors.
        dtype: Dtype of the self and values tensors.
        requires_grad: ``requires_grad`` flag of the self and values tensors.
        **kwargs: Unused.

    Returns:
        A list of ``SampleInput`` objects.
    """
    adv = torch.LongTensor([[0, 1], [2, 3]])
    make = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )

    # (shape of self, indices)
    cases = [
        ((5, 6, 7, 8), (None, adv, adv, None)),
        ((5, 6, 7, 8), (None, adv, None, adv)),
        ((5, 6, 7, 8), (adv, None, None, adv)),
        ((5, 6, 7, 8), (None, None, adv, adv)),
        ((5, 6, 7, 8, 9), (None, None, adv, None, adv)),
        ((5, 6, 7, 8, 9), (None, None, adv, adv, None)),
        ((5, 6, 7, 8, 9, 10), (None, None, None, adv, adv)),
        ((5, 6, 7, 8, 9, 10), (None, None, adv, adv, adv)),
    ]

    samples = []
    for self_shape, indices in cases:
        # Same indices with each None replaced by a full slice, so they can
        # be used to index a tensor and read off the selection's shape.
        as_slices = tuple(slice(None) if entry is None else entry for entry in indices)
        for broadcast_value in (False, True):
            target = make(self_shape)
            values_shape = target[as_slices].shape
            if broadcast_value:
                values_shape = values_shape[3:]
            values = make(values_shape)
            samples.append(SampleInput(target, args=(tuple(indices), values)))
    return samples
def sample_inputs_index_put(op_info, device, dtype, requires_grad, **kwargs):
    """Return index_put samples for both ``accumulate=False`` and ``True``.

    Args:
        op_info: Unused.
        device: Device of every generated tensor; may be a string (with or
            without a device index) or a ``torch.device``.
        dtype: Dtype of the self and values tensors (indices are ``torch.long``).
        requires_grad: ``requires_grad`` flag of the self and values tensors.
        **kwargs: Unused.

    Returns:
        A list of ``SampleInput`` objects.
    """
    make_arg = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )
    make_idx = partial(
        make_tensor, dtype=torch.long, device=device, requires_grad=False
    )
    S = 5
    # Compare the device *type* so that "cuda", "cuda:0" and torch.device
    # objects are all recognised.
    on_cuda = torch.device(device).type == "cuda"
    inputs = []
    for accumulate in [False, True]:
        # putting vectors at indexed locations
        inputs.append(
            SampleInput(
                make_arg((S, S)),
                args=((make_idx((2,), low=0, high=4),), make_arg((2, S))),
                kwargs=dict(accumulate=accumulate),
            )
        )
        # putting multi-dim tensors at indexed locations
        inputs.append(
            SampleInput(
                make_arg((S, S, 2)),
                args=((make_idx((3,), low=0, high=4),), make_arg((3, S, 2))),
                kwargs=dict(accumulate=accumulate),
            )
        )
        # value with size `0` dim
        inputs.append(
            SampleInput(
                make_arg((S, 0)),
                args=((make_idx((3,), low=0, high=4),), make_arg((3, 0))),
                kwargs=dict(accumulate=accumulate),
            )
        )
        # scalar value
        inputs.append(
            SampleInput(
                make_arg((S,)),
                args=((make_idx((), low=0, high=S),), make_arg(())),
                kwargs=dict(accumulate=accumulate),
            )
        )
        # cuda and accumulate don't work well
        # Reference: https://github.com/pytorch/pytorch/issues/72053
        # Only that combination is skipped. The previous condition
        # (`not accumulate and device == "cuda"`) instead dropped the
        # broadcast sample everywhere except non-accumulate on a device
        # spelled exactly "cuda".
        if not (accumulate and on_cuda):
            # Broadcast `values`
            inputs.append(
                SampleInput(
                    make_arg((S, S)),
                    args=((make_idx((2,), low=0, high=S),), make_arg((S,))),
                    kwargs=dict(accumulate=accumulate),
                )
            )
    return inputs
# Registers the "index_put" entry, fed by sample_inputs_index_put.
additional_op_db.append(
    OpInfo(
        "index_put",
        variant_test_name="functorch",
        sample_inputs_func=sample_inputs_index_put,
        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
        supports_forward_ad=True,
        supports_out=False,
    )
)
# Registers the "ops.aten.index_put" entry, fed by
# sample_inputs_aten_index_put.
additional_op_db.append(
    OpInfo(
        "ops.aten.index_put",
        variant_test_name="functorch",
        sample_inputs_func=sample_inputs_aten_index_put,
        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
        supports_forward_ad=True,
        supports_out=False,
    )
)
def sample_inputs_masked_fill(op_info, device, dtype, requires_grad, **kwargs):
    """Yield masked_fill samples that always use the scalar fill value 10.

    Each mask is ``torch.randn(mask_shape) > 0``. The last sample has a mask
    with more dimensions than the input and is flagged
    ``broadcasts_input=True``.

    Args:
        op_info: Unused.
        device: Device of the inputs and masks.
        dtype: Dtype of the inputs.
        requires_grad: ``requires_grad`` flag of the inputs.
        **kwargs: Unused.

    Yields:
        ``SampleInput`` objects with ``args=(mask, 10)``.
    """
    S = 3
    fill_value = 10

    def make_input(shape):
        return make_tensor(
            shape, device=device, dtype=dtype, requires_grad=requires_grad
        )

    def random_mask(shape):
        return torch.randn(shape, device=device) > 0

    # (input shape, mask shape); the input is created before its mask.
    shape_pairs = (
        ((S, S), (S, S)),
        ((S, S), (S,)),
        ((), ()),
        ((S, S), ()),
    )
    for input_shape, mask_shape in shape_pairs:
        inp = make_input(input_shape)
        mask = random_mask(mask_shape)
        yield SampleInput(inp, args=(mask, fill_value))

    inp = make_input((S,))
    mask = random_mask((S, S))
    yield SampleInput(
        inp,
        args=(mask, fill_value),
        broadcasts_input=True,
    )
# Registers the masked_fill entry whose samples all fill with a Python scalar.
additional_op_db.append(
    OpInfo(
        "masked_fill",
        variant_test_name="functorch_Scalar_only",
        sample_inputs_func=sample_inputs_masked_fill,
        dtypes=all_types_and_complex_and(
            torch.bool, torch.half, torch.bfloat16, torch.chalf
        ),
        supports_out=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        check_batched_forward_grad=False,
    )
)
def sample_inputs_new_zeros_with_same_feature_meta(
    op_info, device, dtype, requires_grad, **kwargs
):
    """Return samples pairing a tangent tensor with a base tensor.

    Args:
        op_info: Unused.
        device: Device of every generated tensor.
        dtype: Dtype of every generated tensor.
        requires_grad: ``requires_grad`` flag of every generated tensor.
        **kwargs: Unused.

    Returns:
        A list of ``SampleInput`` objects with the tangent as input, the base
        as the only positional argument, and ``self_num_batch_dims`` as a
        keyword argument.
    """

    def make(shape):
        return make_tensor(
            shape, dtype=dtype, device=device, requires_grad=requires_grad
        )

    # (tangent shape, base shape, num_tangent_bdims)
    cases = (
        ([5], [2, 3], 0),
        ([2, 3], [2, 3], 0),
        ([5], [2], 0),
        ([1, 0, 2], [1, 2], 0),
        ([], [1, 2], 0),
        ([8, 7, 5], [2, 3, 11], 1),
        ([6, 7, 5], [2, 3, 4], 2),
        ([6, 4], [3], 2),
    )

    # Within each sample the tangent is created before the base.
    return [
        SampleInput(
            make(tangent_shape),
            args=(make(base_shape),),
            kwargs={"self_num_batch_dims": num_tangent_bdims},
        )
        for tangent_shape, base_shape, num_tangent_bdims in cases
    ]
# Registers ops.aten._new_zeros_with_same_feature_meta with both autograd
# flags switched off.
additional_op_db.append(
    OpInfo(
        "ops.aten._new_zeros_with_same_feature_meta",
        variant_test_name="functorchonly",
        sample_inputs_func=sample_inputs_new_zeros_with_same_feature_meta,
        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
        supports_autograd=False,
        supports_forward_ad=False,
        supports_out=False,
    )
)
def sample_inputs_conversion(op_info, device, dtype, requires_grad, **kwargs):
    """Yield samples for the dtype-conversion ops registered in this file.

    A 0-D and a (2, 3) tensor are each produced twice: once with no keyword
    arguments and once with ``memory_format=torch.contiguous_format``.

    Args:
        op_info: Unused.
        device: Device of the generated tensors.
        dtype: Dtype of the generated tensors.
        requires_grad: ``requires_grad`` flag of the generated tensors.
        **kwargs: Unused.

    Yields:
        ``SampleInput`` objects, shapes in the outer loop and memory formats
        in the inner loop.
    """
    for shape in ((), (2, 3)):
        for memory_format in (None, torch.contiguous_format):
            tensor = make_tensor(
                shape, dtype=dtype, device=device, requires_grad=requires_grad
            )
            # The keyword is only passed when a memory format was chosen.
            sample_kwargs = {}
            if memory_format:
                sample_kwargs["memory_format"] = memory_format
            yield SampleInput(tensor, kwargs=sample_kwargs)
def _conversion_opinfo(
    name, op, *, xfail_gradients=False, extra_skips=(), **opinfo_kwargs
):
    """Build the OpInfo for one dtype-conversion method.

    Args:
        name: Name of the op.
        op: Callable under test.
        xfail_gradients: When true, the forward and backward gradient test
            classes are marked as expected failures.
        extra_skips: Additional ``DecorateInfo`` entries placed after the
            common ones.
        **opinfo_kwargs: Further keyword arguments passed to ``OpInfo``
            unchanged (the autograd support flags of each entry).

    Returns:
        The configured ``OpInfo``.
    """
    skips = []
    if xfail_gradients:
        # autograd tests don't handle operators that change dtype
        skips.append(DecorateInfo(unittest.expectedFailure, "TestFwdGradients"))
        skips.append(DecorateInfo(unittest.expectedFailure, "TestBwdGradients"))
    skips.append(
        DecorateInfo(
            unittest.expectedFailure,
            "TestNormalizeOperators",
            "test_normalize_operator_exhaustive",
        )
    )
    # RuntimeError: attribute lookup is not defined on builtin
    skips.append(
        DecorateInfo(
            unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
        )
    )
    skips.extend(extra_skips)
    return OpInfo(
        name,
        op=op,
        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
        supports_out=False,
        variant_test_name="functorch_no_channels_last",
        sample_inputs_func=sample_inputs_conversion,
        skips=tuple(skips),
        **opinfo_kwargs,
    )


# One entry per conversion method. Each op is written out as its own literal
# lambda rather than being generated.
additional_op_db.extend(
    [
        _conversion_opinfo(
            "bfloat16",
            lambda x, *args, **kwargs: x.bfloat16(*args, **kwargs),
            xfail_gradients=True,
            extra_skips=(
                DecorateInfo(
                    unittest.skip("Skipped!"), "TestNNCOpInfo", "test_nnc_correctness"
                ),
            ),
        ),
        _conversion_opinfo(
            "bool",
            lambda x, *args, **kwargs: x.bool(*args, **kwargs),
            supports_autograd=False,
        ),
        _conversion_opinfo(
            "byte",
            lambda x, *args, **kwargs: x.byte(*args, **kwargs),
            # The autograd test runner cannot handle functions that change dtype
            supports_autograd=False,
        ),
        _conversion_opinfo(
            "char",
            lambda x, *args, **kwargs: x.char(*args, **kwargs),
            # The autograd test runner cannot handle functions that change dtype
            supports_autograd=False,
        ),
        _conversion_opinfo(
            "double",
            lambda x, *args, **kwargs: x.double(*args, **kwargs),
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
        ),
        _conversion_opinfo(
            "float",
            lambda x, *args, **kwargs: x.float(*args, **kwargs),
            xfail_gradients=True,
        ),
        _conversion_opinfo(
            "half",
            lambda x, *args, **kwargs: x.half(*args, **kwargs),
            xfail_gradients=True,
        ),
        _conversion_opinfo(
            "int",
            lambda x, *args, **kwargs: x.int(*args, **kwargs),
            supports_autograd=False,
        ),
        _conversion_opinfo(
            "long",
            lambda x, *args, **kwargs: x.long(*args, **kwargs),
            supports_autograd=False,
        ),
        _conversion_opinfo(
            "short",
            lambda x, *args, **kwargs: x.short(*args, **kwargs),
            supports_autograd=False,
        ),
    ]
)
)