import itertools
import unittest
from functools import partial

import torch
from torch.testing._internal.common_dtype import (
    all_types_and_complex_and,
    floating_types,
    floating_types_and,
)
from torch.testing._internal.common_methods_invocations import (
    DecorateInfo,
    OpInfo,
    SampleInput,
)
from torch.testing._internal.common_utils import make_tensor


# List of OpInfos that aren't in PyTorch Core yet.
# They are here because we wanted a fast way of writing OpInfos and may not be
# 100% correct (w.r.t. dtypes and other options).
# TODO: Figure out how to upstream these, delete them when they're upstreamed

additional_op_db = []
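
# A minimal usage sketch, for orientation only (the test class, method name,
# and dtype choice below are illustrative and not part of this file). These
# OpInfos are intended for the standard OpInfo-based test machinery, roughly:
#
#     from torch.testing._internal.common_device_type import (
#         instantiate_device_type_tests,
#         ops,
#     )
#     from torch.testing._internal.common_utils import TestCase, run_tests
#
#     class TestAdditionalOps(TestCase):
#         @ops(additional_op_db, allowed_dtypes=(torch.float32,))
#         def test_op(self, device, dtype, op):
#             for sample in op.sample_inputs(device, dtype, requires_grad=False):
#                 op(sample.input, *sample.args, **sample.kwargs)
#
#     instantiate_device_type_tests(TestAdditionalOps, globals())
#
#     if __name__ == "__main__":
#         run_tests()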

# https://github.com/pytorch/pytorch/pull/61068


def sample_inputs_conv2d(
    has_bias, self, device, dtype, requires_grad, extra_args=(), groups=1
):
    in_ch, out_ch = 6, 4
    inp = make_tensor(
        (2, in_ch * groups, 7, 5),
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        low=-1,
        high=1,
    )
    weight = make_tensor(
        (out_ch * groups, in_ch, 3, 2),
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        low=-1,
        high=1,
    )
    bias = None
    if has_bias:
        bias = make_tensor(
            (out_ch * groups,),
            device=device,
            dtype=dtype,
            requires_grad=requires_grad,
            low=-1,
            high=1,
        )
    return [SampleInput(inp, args=((weight, bias) + extra_args))]
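
# For reference: each SampleInput above expands to a call of the form
#     torch.nn.functional.conv2d(inp, weight, bias, *extra_args)
# so, for example, extra_args=((2, 2), (1, 1)) exercises stride=(2, 2) and
# padding=(1, 1) in the variants registered below.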


additional_op_db.extend(
    [
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="no_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(sample_inputs_conv2d, False),
            dtypes=floating_types(),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="with_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(sample_inputs_conv2d, True),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="stride_with_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(sample_inputs_conv2d, True, extra_args=((2, 2))),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="stride_no_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(
                sample_inputs_conv2d, False, extra_args=((2, 2))
            ),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="stride_padding_with_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(
                sample_inputs_conv2d, True, extra_args=((2, 2), (1, 1))
            ),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="stride_padding_no_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(
                sample_inputs_conv2d, False, extra_args=((2, 2), (1, 1))
            ),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="strided_padding_dilation_with_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(
                sample_inputs_conv2d, True, extra_args=((2, 2), (1, 1), (2, 2))
            ),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="strided_padding_dilation_no_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(
                sample_inputs_conv2d, False, extra_args=((2, 2), (1, 1), (2, 2))
            ),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="stride_groups_with_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(
                sample_inputs_conv2d, True, extra_args=((2, 3), 0, 1, 2), groups=2
            ),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
        OpInfo(
            "nn.functional.conv2d",
            aten_name="conv2d",
            variant_test_name="stride_depthwise_with_bias",
            supports_autograd=True,
            supports_forward_ad=True,
            sample_inputs_func=partial(
                sample_inputs_conv2d, True, extra_args=((2, 3), 0, 1, 6), groups=6
            ),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            dtypes=floating_types(),
            supports_out=False,
        ),
    ]
)


# TODO: PyTorch core's sample inputs check whether requires_grad=True or not.
# We want to exercise more backward cases here, which is why we keep our own.
def sample_inputs_embedding(op_info, device, dtype, requires_grad, **kwargs):
    def make_input(shape):
        return make_tensor(
            shape, device=device, dtype=dtype, requires_grad=requires_grad
        )

    def make_long_input(shape, *, low, high):
        return make_tensor(shape, device=device, dtype=torch.long, low=low, high=high)

    M = 20
    S = 5

    def generator():
        # 0-D index tensor
        idx = make_long_input((), low=0, high=M)
        yield SampleInput(
            make_input((M, S)),
            args=(idx,),
        )

        # 1-D index tensor
        idx = make_long_input((S,), low=0, high=M)
        yield SampleInput(
            make_input((M, S)),
            args=(idx,),
        )

        # 2-D index tensor
        idx = make_long_input((S, S), low=0, high=M)
        yield SampleInput(
            make_input((M, S)),
            args=(idx,),
        )

        idx = make_long_input((2, 2), low=0, high=S)
        idx[0, 0] = 2
        idx[1, 1] = 2
        yield SampleInput(
            make_input((S, S)),
            args=(idx,),
            kwargs={"padding_idx": 2},
        )

        idx = make_long_input((2, 2), low=0, high=S)
        idx[0, 0] = 4
        idx[1, 1] = 4
        yield SampleInput(
            make_input((S, S)),
            args=(idx,),
            kwargs={"padding_idx": -1},
        )

        # Scale the gradient based on the inverse frequency of a particular index.
        idx = make_long_input((2, 2), low=0, high=S)
        idx[0, 0] = 1
        idx[0, 1] = 1
        weights = make_input((S, S))
        yield SampleInput(
            weights,
            args=(idx,),
            kwargs={"scale_grad_by_freq": True},
        )

    return list(generator())


additional_op_db.append(
    OpInfo(
        "nn.functional.embedding",
        variant_test_name="functorch",
        # We use a lambda to reshuffle the positional arguments.
        # This is because currently only the `input` field of SampleInput
        # is tested in gradient tests.
        op=lambda weight, idx, **kwargs: torch.nn.functional.embedding(
            idx, weight, **kwargs
        ),
        dtypes=floating_types_and(torch.bfloat16, torch.float16),
        sample_inputs_func=sample_inputs_embedding,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        supports_out=False,
    )
)
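
# Reading aid: because of the lambda above, a sample built as
# SampleInput(weight, args=(idx,)) is invoked as op(weight, idx), which calls
# torch.nn.functional.embedding(idx, weight). In other words, the tensor whose
# gradients the OpInfo machinery checks (the SampleInput `input`) lands in the
# `weight` slot of F.embedding, while `idx` becomes the indices argument.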


def sample_inputs_mse_loss(op_info, device, dtype, requires_grad, **kwargs):
    def make_input(shape, requires_grad=requires_grad):
        return make_tensor(
            shape, device=device, dtype=dtype, requires_grad=requires_grad
        )

    rhs_requires_grad = kwargs.get("rhs_requires_grad", requires_grad)
    S = 5

    shapes = ((S, S), (S, S, S), (S, S, S, S))
    reductions = ("none", "mean", "sum")

    for shape, reduction in itertools.product(shapes, reductions):
        yield SampleInput(
            make_input(shape),
            args=(make_input(shape, requires_grad=rhs_requires_grad),),
            kwargs={"reduction": reduction},
        )


additional_op_db.append(
    OpInfo(
        "nn.functional.mse_loss",
        variant_test_name="functorch",
        sample_inputs_func=sample_inputs_mse_loss,
        supports_out=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        dtypes=floating_types_and(torch.float16),
        backward_dtypes=floating_types(),
        dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
        backward_dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
    )
)


# TODO: upstream sample inputs to pytorch/pytorch.
# These are more comprehensive than the in-core ones.
def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs):
    # Short for "advanced index"
    adv_idx = torch.LongTensor([[0, 1], [2, 3]])
    S = 5
    # self_dim, indices
    test_args = [
        (3, ([1, 2],)),
        (3, (slice(0, 3),)),
        (3, ((slice(0, 3), 1),)),
        (3, (([0, 2, 3], [1, 3, 3], [0, 0, 2]),)),
        (3, (([0, 0, 3], [1, 1, 3], [0, 0, 2]),)),
        (3, ((slice(None), slice(None), [0, 3]),)),
        (3, ((slice(None), [0, 3], slice(None)),)),
        (3, (([0, 3], slice(None), slice(None)),)),
        (3, (([0, 3], [1, 2], slice(None)),)),
        (
            3,
            (
                [
                    [0, 3],
                ],
            ),
        ),
        (3, (([0, 3], slice(None)),)),
        (3, (([0, 3], Ellipsis),)),
        (3, (([0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])),)),
        (4, ((slice(None), adv_idx, adv_idx, slice(None)),)),
        (4, ((slice(None), adv_idx, slice(None), adv_idx),)),
        (4, ((adv_idx, slice(None), slice(None), adv_idx),)),
        (4, ((slice(None), slice(None), adv_idx, adv_idx),)),
        (4, ((Ellipsis, adv_idx, adv_idx),)),
        (5, ((slice(None), slice(None), adv_idx, slice(None), adv_idx),)),
        (5, ((slice(None), slice(None), adv_idx, adv_idx, slice(None)),)),
        (5, ((slice(None), slice(None), adv_idx, None, adv_idx, slice(None)),)),
        (6, ((slice(None), slice(None), slice(None), adv_idx, adv_idx),)),
        (6, ((slice(None), slice(None), adv_idx, adv_idx, adv_idx),)),
        (6, ((slice(None), slice(None), None, adv_idx, adv_idx, adv_idx),)),
    ]

    def get_shape(dim):
        return tuple(S + i for i in range(dim))

    return tuple(
        SampleInput(
            make_tensor(
                get_shape(self_dim),
                device=device,
                dtype=dtype,
                low=None,
                high=None,
                requires_grad=requires_grad,
            ),
            args=args,
        )
        for self_dim, args in test_args
    )
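
# Reading aid: get_shape(d) above produces a tensor of shape (5, 6, ..., 5 + d - 1),
# and each (self_dim, indices) pair corresponds to plain Python indexing on it.
# For example, (3, ((slice(None), slice(None), [0, 3]),)) exercises x[:, :, [0, 3]]
# on a tensor x of shape (5, 6, 7), and entries containing `adv_idx` exercise
# advanced indexing with a 2x2 LongTensor.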


# TODO: split PyTorch's __getitem__. The problem is we don't support indexing
# with masks with vmap.
additional_op_db.append(
    OpInfo(
        "__getitem__",
        variant_test_name="functorch",
        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
        supports_out=False,
        supports_inplace_autograd=False,
        supports_scripting=False,
        op=torch.Tensor.__getitem__,
        assert_jit_shape_analysis=False,  # TODO: support index.Tensor()
        supports_forward_ad=True,
        sample_inputs_func=sample_inputs_getitem,
    )
)


# Turns out at::index_put is different from torch.index_put...
# TODO: figure out how to upstream this
def sample_inputs_aten_index_put(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )
    inputs = []
    adv_idx = torch.LongTensor([[0, 1], [2, 3]])
    # self_shape, indices
    additional = [
        ((5, 6, 7, 8), (None, adv_idx, adv_idx, None)),
        ((5, 6, 7, 8), (None, adv_idx, None, adv_idx)),
        ((5, 6, 7, 8), (adv_idx, None, None, adv_idx)),
        ((5, 6, 7, 8), (None, None, adv_idx, adv_idx)),
        ((5, 6, 7, 8, 9), (None, None, adv_idx, None, adv_idx)),
        ((5, 6, 7, 8, 9), (None, None, adv_idx, adv_idx, None)),
        ((5, 6, 7, 8, 9, 10), (None, None, None, adv_idx, adv_idx)),
        ((5, 6, 7, 8, 9, 10), (None, None, adv_idx, adv_idx, adv_idx)),
    ]
    for self_shape, indices in additional:
        for broadcast_value in [False, True]:
            inp = make_arg(self_shape)

            tmp_indices = tuple(
                [slice(None) if idx is None else idx for idx in indices]
            )
            values_shape = inp[tmp_indices].shape
            if broadcast_value:
                values_shape = values_shape[3:]
            values = make_arg(values_shape)
            inputs.append(SampleInput(inp, args=(tuple(indices), values)))
    return inputs
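
# Note on the samples above: the aten-level overload accepts `None` entries in
# the index list, which mean "take every position along that dimension" (the
# helper mirrors this by substituting slice(None) when computing the shape of
# `values`). The Tensor.index_put samples below stick to plain tuples of index
# tensors instead.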


def sample_inputs_index_put(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )
    make_idx = partial(
        make_tensor, dtype=torch.long, device=device, requires_grad=False
    )
    S = 5
    inputs = []
    for accumulate in [False, True]:
        # putting vectors at indexed locations
        inputs.append(
            SampleInput(
                make_arg((S, S)),
                args=((make_idx((2,), low=0, high=4),), make_arg((2, S))),
                kwargs=dict(accumulate=accumulate),
            )
        )

        # putting multi-dim tensors at indexed locations
        inputs.append(
            SampleInput(
                make_arg((S, S, 2)),
                args=((make_idx((3,), low=0, high=4),), make_arg((3, S, 2))),
                kwargs=dict(accumulate=accumulate),
            )
        )

        # value with size `0` dim
        inputs.append(
            SampleInput(
                make_arg((S, 0)),
                args=((make_idx((3,), low=0, high=4),), make_arg((3, 0))),
                kwargs=dict(accumulate=accumulate),
            )
        )

        # scalar value
        inputs.append(
            SampleInput(
                make_arg((S,)),
                args=((make_idx((), low=0, high=S),), make_arg(())),
                kwargs=dict(accumulate=accumulate),
            )
        )

        # cuda and accumulate don't work well
        # Reference: https://github.com/pytorch/pytorch/issues/72053
        if not accumulate and device == "cuda":
            # Broadcast `values`
            inputs.append(
                SampleInput(
                    make_arg((S, S)),
                    args=((make_idx((2,), low=0, high=S),), make_arg((S,))),
                    kwargs=dict(accumulate=accumulate),
                )
            )

    return inputs


additional_op_db.append(
    OpInfo(
        "index_put",
        variant_test_name="functorch",
        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
        supports_out=False,
        sample_inputs_func=sample_inputs_index_put,
        supports_forward_ad=True,
    )
)
additional_op_db.append(
    OpInfo(
        "ops.aten.index_put",
        variant_test_name="functorch",
        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
        supports_out=False,
        sample_inputs_func=sample_inputs_aten_index_put,
        supports_forward_ad=True,
    )
)


def sample_inputs_masked_fill(op_info, device, dtype, requires_grad, **kwargs):
    S = 3
    make_arg = partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )

    yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, 10))
    yield SampleInput(make_arg((S, S)), args=(torch.randn(S, device=device) > 0, 10))
    yield SampleInput(make_arg(()), args=(torch.randn((), device=device) > 0, 10))
    yield SampleInput(make_arg((S, S)), args=(torch.randn((), device=device) > 0, 10))
    yield SampleInput(
        make_arg((S,)),
        args=(torch.randn(S, S, device=device) > 0, 10),
        broadcasts_input=True,
    )
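
# Note: every sample above passes a Python scalar (10) as the fill value, so the
# OpInfo below only covers the masked_fill(mask, Scalar) overload; hence the
# "functorch_Scalar_only" variant name. The Tensor-valued overload is not
# sampled here.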


additional_op_db.append(
    OpInfo(
        "masked_fill",
        variant_test_name="functorch_Scalar_only",
        dtypes=all_types_and_complex_and(
            torch.bool, torch.half, torch.bfloat16, torch.chalf
        ),
        sample_inputs_func=sample_inputs_masked_fill,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        check_batched_forward_grad=False,
        supports_out=False,
    )
)


def sample_inputs_new_zeros_with_same_feature_meta(
    op_info, device, dtype, requires_grad, **kwargs
):
    make_arg = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )
    matrix = [
        # tangent, base, num_tangent_bdims
        ([5], [2, 3], 0),
        ([2, 3], [2, 3], 0),
        ([5], [2], 0),
        ([1, 0, 2], [1, 2], 0),
        ([], [1, 2], 0),
        ([8, 7, 5], [2, 3, 11], 1),
        ([6, 7, 5], [2, 3, 4], 2),
        ([6, 4], [3], 2),
    ]
    results = []
    for tangent_shape, base_shape, num_tangent_bdims in matrix:
        tangent = make_arg(tangent_shape)
        base = make_arg(base_shape)
        results.append(
            SampleInput(
                tangent,
                args=(base,),
                kwargs=dict(self_num_batch_dims=num_tangent_bdims),
            )
        )
    return results
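
# Rough reading of these samples (not a spec of the op): given a `tangent` whose
# leading `self_num_batch_dims` dimensions are batch dims and a `base` tensor,
# ops.aten._new_zeros_with_same_feature_meta is expected to produce a zero
# tensor whose trailing ("feature") dims mirror `base` while keeping those
# leading batch dims, e.g. tangent (8, 7, 5) with one batch dim and base
# (2, 3, 11) would yield shape (8, 2, 3, 11). See the `matrix` entries above.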


additional_op_db.append(
    OpInfo(
        "ops.aten._new_zeros_with_same_feature_meta",
        variant_test_name="functorchonly",
        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
        supports_out=False,
        supports_autograd=False,
        supports_forward_ad=False,
        sample_inputs_func=sample_inputs_new_zeros_with_same_feature_meta,
    )
)


def sample_inputs_conversion(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )
    shapes = ((), (2, 3))
    memory_format_options = [None, torch.contiguous_format]
    for shape, memory_format in itertools.product(shapes, memory_format_options):
        yield SampleInput(
            make_arg(shape),
            kwargs={"memory_format": memory_format} if memory_format else {},
        )
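
# Reading aid: each OpInfo below wraps a Tensor conversion method, so a sample
# from sample_inputs_conversion turns into a call like x.bfloat16() or
# x.bfloat16(memory_format=torch.contiguous_format), and likewise for .bool(),
# .byte(), and the rest. The "functorch_no_channels_last" variant name reflects
# that channels_last memory formats are deliberately not sampled here.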


additional_op_db.extend(
    [
        OpInfo(
            "bfloat16",
            op=lambda x, *args, **kwargs: x.bfloat16(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            skips=(
                # autograd tests don't handle operators that change dtype
                DecorateInfo(unittest.expectedFailure, "TestFwdGradients"),
                DecorateInfo(unittest.expectedFailure, "TestBwdGradients"),
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
                DecorateInfo(
                    unittest.skip("Skipped!"), "TestNNCOpInfo", "test_nnc_correctness"
                ),
            ),
        ),
        OpInfo(
            "bool",
            op=lambda x, *args, **kwargs: x.bool(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            supports_autograd=False,
            skips=(
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
        OpInfo(
            "byte",
            op=lambda x, *args, **kwargs: x.byte(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            # The autograd test runner cannot handle functions that change dtype
            supports_autograd=False,
            skips=(
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
        OpInfo(
            "char",
            op=lambda x, *args, **kwargs: x.char(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            # The autograd test runner cannot handle functions that change dtype
            supports_autograd=False,
            skips=(
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
        OpInfo(
            "double",
            op=lambda x, *args, **kwargs: x.double(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            skips=(
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
        OpInfo(
            "float",
            op=lambda x, *args, **kwargs: x.float(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            skips=(
                # autograd tests don't handle operators that change dtype
                DecorateInfo(unittest.expectedFailure, "TestFwdGradients"),
                DecorateInfo(unittest.expectedFailure, "TestBwdGradients"),
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
        OpInfo(
            "half",
            op=lambda x, *args, **kwargs: x.half(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            skips=(
                # autograd tests don't handle operators that change dtype
                DecorateInfo(unittest.expectedFailure, "TestFwdGradients"),
                DecorateInfo(unittest.expectedFailure, "TestBwdGradients"),
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
        OpInfo(
            "int",
            op=lambda x, *args, **kwargs: x.int(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            supports_autograd=False,
            skips=(
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
        OpInfo(
            "long",
            op=lambda x, *args, **kwargs: x.long(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            supports_autograd=False,
            skips=(
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
        OpInfo(
            "short",
            op=lambda x, *args, **kwargs: x.short(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_out=False,
            variant_test_name="functorch_no_channels_last",
            sample_inputs_func=sample_inputs_conversion,
            supports_autograd=False,
            skips=(
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestNormalizeOperators",
                    "test_normalize_operator_exhaustive",
                ),
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(
                    unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
                ),
            ),
        ),
    ]
)