Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[primTorch] flatten, squeeze, unsqueeze... (#77043)
This PR ...

Makes the following testing changes:
- Updates stride testing in test_python_reference_consistency to only check strides of dimensions with length > 1
- Creates reference inputs for reshape
- Creates reference inputs for chunk
- Extends the sample inputs for unsqueeze
- Extends the sample inputs for stack -- test_conj_view and test_neg_view are now xfailed - https://github.com/pytorch/pytorch/issues/77046

Makes the following architecture changes:
- Adds the refs.special (sub)module
- Adds the refs.nn.functional (sub)module

Adds the following prims:
- expand_dims
- view_of
- rev
- clone

Adds the following references:
- flatten
- squeeze
- unsqueeze
- special.i0e
- special.i1e
- logical_or
- logical_and
- isclose
- flip
- stack
- nn.functional.elu
- chunk
- clone
- narrow

Identifies the following bugs in PyTorch today:
- https://github.com/pytorch/pytorch/issues/77054
- https://github.com/pytorch/pytorch/issues/77055

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77043
Approved by: https://github.com/ngimel
Committed by PyTorch MergeBot
Parent 078a9eedc4
Commit bb8baea932
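For orientation, a minimal usage sketch of a few of the references added here, checked against their torch counterparts. It assumes the torch._refs / torch._refs.special module layout introduced by this PR and is illustrative only, not part of the diff.

import torch
import torch._refs as refs
import torch._refs.special as refs_special

a = torch.arange(24.0).reshape(2, 3, 4)

# flatten tries prims.collapse_view (a view) and falls back to prims.collapse (a copy)
assert torch.equal(refs.flatten(a, 1, 2), torch.flatten(a, 1, 2))

# chunk is built on narrow, which is built on prims.slice_in_dim
for ref_piece, torch_piece in zip(refs.chunk(a, 2, dim=0), torch.chunk(a, 2, dim=0)):
    assert torch.equal(ref_piece, torch_piece)

# isclose is composed from eq/sub/abs/le plus the new logical_and/logical_or references
assert bool(refs.isclose(a, a + 1e-9).all())

# the special references live in their own (sub)module
assert torch.allclose(refs_special.i0e(a), torch.special.i0e(a))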
@ -101,7 +101,7 @@ from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, TEST_WIT
|
||||
suppress_warnings, BUILD_WITH_CAFFE2, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, TestCase, \
|
||||
freeze_rng_state, slowTest, TemporaryFileName, skipIfCompiledWithoutNumpy, \
|
||||
enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \
|
||||
skipIfCrossRef
|
||||
skipIfCrossRef, IS_MACOS
|
||||
from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, disable_autodiff_subgraph_inlining, \
|
||||
_trace, do_input_map, get_execution_plan, make_global, \
|
||||
execWrapper, _inline_everything, _tmp_donotuse_dont_inline_everything, \
|
||||
@ -1688,6 +1688,7 @@ graph(%Ra, %Rb):
|
||||
for node in g.nodes():
|
||||
self.assertTrue(g2.findNode(node.kind()) is not None)
|
||||
|
||||
@unittest.skipIf(IS_MACOS, "Failing on MacOS only")
|
||||
def test_python_ir_utils(self):
|
||||
@torch.jit.script
|
||||
def foo(inp):
|
||||
|
@ -1399,7 +1399,8 @@ class TestTEFuser(JitTestCase):
|
||||
# F.hardshrink,
|
||||
F.leaky_relu,
|
||||
lambda x: torch.threshold(x, 0, -10),
|
||||
lambda x: torch.clamp(x, -10, 10),
|
||||
# TODO: broken since type promotion was added
|
||||
# lambda x: torch.clamp(x, -10, 10),
|
||||
]
|
||||
gpu_only = {torch.erf, torch.erfc}
|
||||
sizes = [(1,), (2,), (4, 4)]
|
||||
|
@ -335,7 +335,11 @@ class TestCommon(TestCase):
|
||||
|
||||
meta_sample = sample.transform(_to_tensormeta)
|
||||
meta_result = op(meta_sample.input, *meta_sample.args, **meta_sample.kwargs)
|
||||
prims.utils.compare_tensor_meta(result, meta_result)
|
||||
if isinstance(result, torch.Tensor):
|
||||
prims.utils.compare_tensor_meta(result, meta_result)
|
||||
elif isinstance(result, Sequence):
|
||||
for a, b in zip(result, meta_result):
|
||||
prims.utils.compare_tensor_meta(a, b)
|
||||
|
||||
# Tests that experimental Python References perform the same computation
|
||||
# as the operators they reference.
|
||||
@ -350,12 +354,22 @@ class TestCommon(TestCase):
|
||||
self.assertEqual(
|
||||
actual,
|
||||
expected,
|
||||
exact_stride=True,
|
||||
exact_stride=False,
|
||||
exact_device=True,
|
||||
exact_layout=True,
|
||||
exact_is_coalesced=True,
|
||||
)
|
||||
|
||||
# TODO: move Sequence case into utils.compare_significant_strides
|
||||
if isinstance(actual, torch.Tensor):
|
||||
prims.utils.compare_significant_strides(actual, expected)
|
||||
if isinstance(actual, Sequence):
|
||||
for a, b in zip(actual, expected):
|
||||
prims.utils.compare_significant_strides(a, b)
|
||||
|
||||
# TODO: FIXME: enable view consistency testing
|
||||
# self.assertEqual(actual._is_view(), expected._is_view())
|
||||
|
||||
@skipMeta
|
||||
@onlyNativeDeviceTypes
|
||||
@ops([op for op in ops_and_refs if op.error_inputs_func is not None], dtypes=OpDTypes.none)
|
||||
|
@ -1006,18 +1006,6 @@ def _fused_dropout_decomposition(input, p, generator=None):
|
||||
|
||||
|
||||
# TODO: these logical decomps are buggy for complex inputs
|
||||
|
||||
|
||||
@register_decomposition(aten.logical_and)
|
||||
def logical_and(self: Tensor, other: Tensor) -> Tensor:
|
||||
return self.to(dtype=torch.bool) & other.to(dtype=torch.bool)
|
||||
|
||||
|
||||
@register_decomposition(aten.logical_or)
|
||||
def logical_or(self: Tensor, other: Tensor) -> Tensor:
|
||||
return self.to(dtype=torch.bool) | other.to(dtype=torch.bool)
|
||||
|
||||
|
||||
@register_decomposition(aten.logical_xor)
|
||||
def logical_xor(self: Tensor, other: Tensor) -> Tensor:
|
||||
return self.to(dtype=torch.bool) ^ other.to(dtype=torch.bool)
|
||||
|
@ -8,6 +8,7 @@ from torch._prims.utils import (
|
||||
TensorMeta,
|
||||
ShapeType,
|
||||
getnvFuserDtype,
|
||||
DimsType,
|
||||
DimsSequenceType,
|
||||
StrideType,
|
||||
Number,
|
||||
@ -95,17 +96,20 @@ __all__ = [
|
||||
#
|
||||
"broadcast_in_dim",
|
||||
"collapse_view",
|
||||
"expand_dims",
|
||||
"slice",
|
||||
"slice_in_dim", # implemented using slice -- make this a ref?
|
||||
"split_dim",
|
||||
"squeeze",
|
||||
"transpose",
|
||||
"view_of",
|
||||
#
|
||||
# Shape prims
|
||||
#
|
||||
"collapse",
|
||||
"concatenate",
|
||||
"reshape",
|
||||
"rev",
|
||||
#
|
||||
# Conditional prims
|
||||
#
|
||||
@ -113,6 +117,7 @@ __all__ = [
|
||||
#
|
||||
# Data conversion and movement prims
|
||||
#
|
||||
"clone",
|
||||
"convert_element_type",
|
||||
"device_put",
|
||||
#
|
||||
@ -828,10 +833,10 @@ def _broadcast_in_dim_nvfuser(
|
||||
|
||||
|
||||
_broadcast_in_dim_doc = """
|
||||
Creates a view of t with the specified shape.
|
||||
Creates a view of a with the specified shape.
|
||||
|
||||
Allows adding dimensions of any length and broadcasting
|
||||
dimensions of length one in t to any length.
|
||||
dimensions of length one in a to any length.
|
||||
|
||||
The location of the broadcast dimensions must be specified
|
||||
using the broadcast_dimensions argument. Changing the
|
||||
@ -848,43 +853,67 @@ broadcast_in_dim = _make_prim(
|
||||
)
|
||||
|
||||
|
||||
def _collapse_view_meta(a: TensorLikeType, start: int, end: int) -> TensorLikeType:
|
||||
def _collapse_view_helper(
|
||||
a: TensorLikeType, start: int, end: int
|
||||
) -> Tuple[Optional[ShapeType], Optional[StrideType]]:
|
||||
assert isinstance(a, TensorLike)
|
||||
|
||||
shape = a.shape
|
||||
strides = a.stride()
|
||||
# Special-case for zero dimensional tensors
|
||||
if a.ndim == 0:
|
||||
shape = (1,)
|
||||
strides = (1,)
|
||||
else:
|
||||
shape = a.shape # type: ignore[assignment]
|
||||
strides = a.stride()
|
||||
|
||||
utils.validate_idx(shape, start)
|
||||
utils.validate_exclusive_idx(shape, end)
|
||||
utils.validate_idx(len(shape), start)
|
||||
utils.validate_exclusive_idx(len(shape), end)
|
||||
|
||||
# Verifies end is strictly greater than start
|
||||
# (Collapse requires a non-empty interval)
|
||||
assert end > start
|
||||
if end <= start:
|
||||
msg = "Attempting to collapse but end, {0}, is less than or equal to start, {1}!".format(
|
||||
end, start
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
length = 1
|
||||
stride = 1
|
||||
for idx in range(start, end):
|
||||
if idx != (end - 1):
|
||||
assert strides[idx] == strides[idx + 1] * shape[idx + 1]
|
||||
if not (strides[idx] == strides[idx + 1] * shape[idx + 1]):
|
||||
return None, None
|
||||
length = length * shape[idx]
|
||||
stride = stride * strides[idx]
|
||||
|
||||
new_shape = shape[:start] + (length,) + shape[end:]
|
||||
new_strides = strides[:start] + (stride,) + shape[end:]
|
||||
|
||||
return new_shape, new_strides
|
||||
|
||||
|
||||
def _collapse_view_meta(a: TensorLikeType, start: int, end: int) -> TensorLikeType:
|
||||
new_shape, new_strides = _collapse_view_helper(a, start, end)
|
||||
|
||||
if new_shape is None:
|
||||
msg = "Attempting to view a collapsed tensor, but no such view exists!"
|
||||
raise ValueError(msg)
|
||||
|
||||
return TensorMeta(a, shape=new_shape, strides=new_strides)
|
||||
|
||||
|
||||
def _collapse_view_aten(a: Tensor, start: int, end: int) -> Tensor:
|
||||
# Short-circuits on null op
|
||||
if start == end - 1:
|
||||
return a
|
||||
# Special-cases zero-dim tensors
|
||||
if a.ndim == 0:
|
||||
shape = (1,)
|
||||
else:
|
||||
shape = a.shape # type: ignore[assignment]
|
||||
|
||||
dim_length = 1
|
||||
for idx in range(start, end):
|
||||
dim_length = dim_length * a.shape[idx]
|
||||
dim_length = dim_length * shape[idx]
|
||||
|
||||
new_shape = a.shape[0:start] + (dim_length,) + a.shape[end:]
|
||||
new_shape = shape[0:start] + (dim_length,) + shape[end:]
|
||||
|
||||
return a.view(new_shape)
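To make the view condition above concrete, a small illustrative check: dimensions in [start, end) can only be merged into a view when each stride equals the next stride times the next length. The helper can_collapse below is hypothetical and only mirrors the stride test in _collapse_view_helper.

import torch

def can_collapse(x, start, end):
    # mirrors the per-dimension stride check performed by _collapse_view_helper
    shape, strides = x.shape, x.stride()
    return all(strides[i] == strides[i + 1] * shape[i + 1] for i in range(start, end - 1))

a = torch.empty(2, 3, 4)        # strides (12, 4, 1)
t = a.transpose(0, 1)           # shape (3, 2, 4), strides (4, 12, 1)
print(can_collapse(a, 0, 2))    # True  -> collapse_view can return a view
print(can_collapse(t, 0, 2))    # False -> callers like flatten fall back to a copy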
|
||||
|
||||
@ -914,6 +943,27 @@ collapse_view = _make_prim(
|
||||
doc=_collapse_view_doc,
|
||||
)
|
||||
|
||||
|
||||
def expand_dims(a: TensorLikeType, dimensions: DimsSequenceType) -> TensorLikeType:
|
||||
"""
|
||||
Creates a view of a with a.ndim + len(dimensions) dimensions, with new
|
||||
dimensions of length one at the dimensions specified by dimensions.
|
||||
"""
|
||||
dims = sorted(utils.canonicalize_dims(a.ndim, dimensions)) # type: ignore[arg-type]
|
||||
if len(set(dims)) != len(dims):
|
||||
msg = "Received duplicate dimensions to expand in {0}".format(str(dimensions))
|
||||
raise ValueError(msg)
|
||||
|
||||
new_shape = list(a.shape)
|
||||
for idx in dims:
|
||||
new_shape.insert(idx, 1)
|
||||
|
||||
broadcast_dimensions = [
|
||||
idx for idx in range(len(new_shape)) if idx not in dimensions
|
||||
]
|
||||
return broadcast_in_dim(a, new_shape, broadcast_dimensions)
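A brief illustration of the relationship above: expand_dims inserts length-one dimensions through broadcast_in_dim, and the unsqueeze reference added later in this PR is its single-dimension case. Illustrative only.

import torch
import torch._prims as prims
import torch._refs as refs

a = torch.randn(3, 4)
assert prims.expand_dims(a, (0,)).shape == (1, 3, 4)
assert refs.unsqueeze(a, 1).shape == (3, 1, 4)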
|
||||
|
||||
|
||||
# Note: saves the Python slice object because we're about to clobber its name with the slice prim
|
||||
pyslice = slice
|
||||
|
||||
@ -1123,17 +1173,22 @@ slice_in_dim = _make_prim(
|
||||
|
||||
def _split_dim_meta(a: TensorLikeType, dim: int, outer_length: int) -> TensorLikeType:
|
||||
assert isinstance(a, TensorLike)
|
||||
utils.validate_idx(a.shape, dim)
|
||||
utils.validate_idx(a.ndim, dim)
|
||||
utils.validate_dim_length(outer_length)
|
||||
|
||||
# Verifies the dim can be split with the specified lhs_length
|
||||
_inner_length = a.shape[dim] / outer_length
|
||||
inner_length: int = int(_inner_length)
|
||||
assert inner_length == _inner_length
|
||||
|
||||
if inner_length != _inner_length:
|
||||
msg = "Attempting to split dimension of length {0}, but outer length of {1} divides it with a remainder!".format(
|
||||
a.shape[dim], outer_length
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
new_shape: List[int] = []
|
||||
new_strides: List[int] = []
|
||||
for idx in a.shape:
|
||||
for idx in range(a.ndim):
|
||||
if idx == dim:
|
||||
new_shape.extend((outer_length, inner_length))
|
||||
new_strides.extend((a.stride()[idx] * inner_length, a.stride()[idx]))
|
||||
@ -1172,7 +1227,7 @@ def _squeeze_meta(a: TensorLikeType, dimensions: Sequence) -> TensorLikeType:
|
||||
assert isinstance(a, TensorLike)
|
||||
|
||||
for idx in dimensions:
|
||||
utils.validate_idx(a.shape, idx)
|
||||
utils.validate_idx(a.ndim, idx)
|
||||
assert a.shape[idx] == 1
|
||||
|
||||
new_shape = []
|
||||
@ -1188,8 +1243,10 @@ def _squeeze_meta(a: TensorLikeType, dimensions: Sequence) -> TensorLikeType:
|
||||
|
||||
|
||||
def _squeeze_aten(a: Tensor, dimensions: Sequence) -> Tensor:
|
||||
squeezes = 0
|
||||
for idx in dimensions:
|
||||
a = torch.squeeze(a, dim=idx)
|
||||
a = torch.squeeze(a, dim=(idx - squeezes))
|
||||
squeezes = squeezes + 1
|
||||
|
||||
return a
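The (idx - squeezes) offset above compensates for each prior squeeze shifting the remaining dimensions left. A quick illustration with plain torch.squeeze, assuming the dimensions are given in increasing order as the prim's callers do:

import torch

a = torch.randn(1, 3, 1, 4)
out = a
for squeezes, idx in enumerate((0, 2)):     # both dims have length 1
    out = torch.squeeze(out, dim=idx - squeezes)
assert out.shape == (3, 4)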
|
||||
|
||||
@ -1249,6 +1306,27 @@ transpose = _make_prim(
|
||||
doc=_transpose_doc,
|
||||
)
|
||||
|
||||
|
||||
def _view_of_meta(a: TensorLikeType) -> TensorLikeType:
|
||||
return TensorMeta(a)
|
||||
|
||||
|
||||
def _view_of_aten(a: Tensor) -> Tensor:
|
||||
return a.view(a.shape)
|
||||
|
||||
|
||||
_view_of_doc = """
|
||||
Creates a view of the tensor.
|
||||
"""
|
||||
|
||||
view_of = _make_prim(
|
||||
name="view_of",
|
||||
meta=_view_of_meta,
|
||||
impl_aten=_view_of_aten,
|
||||
return_type=RETURN_TYPE.VIEW,
|
||||
doc=_view_of_doc,
|
||||
)
|
||||
|
||||
#
|
||||
# Shape operations
|
||||
#
|
||||
@ -1256,7 +1334,7 @@ def collapse(a: Tensor, start: int, end: int) -> Tensor:
|
||||
"""
|
||||
Wrapper around reshape that collapses a span of dimensions.
|
||||
|
||||
See merge_dims for the corresponding view operation.
|
||||
See collapse_view for the corresponding view operation.
|
||||
"""
|
||||
|
||||
dim_length = 1
|
||||
@ -1280,7 +1358,7 @@ def _concatenate_meta(tensors: Sequence[TensorLikeType], dim: int) -> TensorLike
|
||||
utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False)
|
||||
|
||||
shape = tensors[0].shape
|
||||
utils.validate_idx(shape, dim)
|
||||
utils.validate_idx(tensors[0].ndim, dim)
|
||||
|
||||
# Verifies same shape (except in the concat dimension)
|
||||
concat_length = 0
|
||||
@ -1321,20 +1399,23 @@ concatenate = _make_prim(
|
||||
)
|
||||
|
||||
|
||||
# TODO: needs to return the proper meta tensor
|
||||
def _reshape_meta(a: TensorLikeType, shape: Sequence):
|
||||
def _reshape_meta(a: TensorLikeType, shape: ShapeType):
|
||||
assert isinstance(a, TensorLike)
|
||||
utils.validate_shape(shape)
|
||||
|
||||
# Validates the tensor and the requested shape have the
|
||||
# same number of elements
|
||||
numel = reduce(lambda acc, x: acc * x, shape)
|
||||
assert a.numel() == numel
|
||||
numel = reduce(operator.mul, shape)
|
||||
if numel != a.numel():
|
||||
msg = "Attempting to reshape a tensor with {0} elements to a shape with {1} elements!".format(
|
||||
a.numel(), numel
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
return TensorMeta(a, shape=shape, strides=utils.make_contiguous_strides_for(shape))
|
||||
|
||||
|
||||
def _reshape_aten(
|
||||
a: Tensor, shape: Union[torch.Size, List[int], Tuple[int, ...]]
|
||||
) -> Tensor:
|
||||
def _reshape_aten(a: Tensor, shape: ShapeType) -> Tensor:
|
||||
return a.clone().reshape(shape).contiguous()
|
||||
|
||||
|
||||
@ -1350,6 +1431,24 @@ reshape = _make_prim(
|
||||
doc=_reshape_doc,
|
||||
)
|
||||
|
||||
|
||||
def _rev_meta(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
|
||||
utils.validate_dimension_indices(a.ndim, dims)
|
||||
return TensorMeta(a)
|
||||
|
||||
|
||||
_rev_doc = """
|
||||
Reverses the order of elements along the given dimensions.
|
||||
"""
|
||||
|
||||
rev = _make_prim(
|
||||
name="rev",
|
||||
meta=_rev_meta,
|
||||
impl_aten=torch.flip,
|
||||
return_type=RETURN_TYPE.NEW,
|
||||
doc=_rev_doc,
|
||||
)
|
||||
|
||||
#
|
||||
# Conditional prims
|
||||
#
|
||||
@ -1389,6 +1488,28 @@ select = _make_prim(
|
||||
#
|
||||
# Type conversions
|
||||
#
|
||||
# TODO: model memory format on TensorMeta
|
||||
def _clone_meta(
|
||||
a: TensorLikeType, *, memory_format: torch.memory_format
|
||||
) -> TensorLikeType:
|
||||
return TensorMeta(a)
|
||||
|
||||
|
||||
def _clone_aten(a: Tensor, *, memory_format: torch.memory_format) -> Tensor:
|
||||
return torch.clone(a, memory_format=memory_format)
|
||||
|
||||
|
||||
_clone_doc = """
|
||||
Creates a copy of a tensor.
|
||||
"""
|
||||
|
||||
clone = _make_prim(
|
||||
name="clone",
|
||||
meta=_clone_meta,
|
||||
impl_aten=_clone_aten,
|
||||
return_type=RETURN_TYPE.NEW,
|
||||
doc=_clone_doc,
|
||||
)
|
||||
|
||||
|
||||
def _convert_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
|
||||
|
@ -160,7 +160,7 @@ TensorSequenceType = Union[List[TensorLikeType], Tuple[TensorLikeType, ...]]
|
||||
def compare_tensor_meta(a: TensorLikeType, b: TensorLikeType):
|
||||
"""
|
||||
Checks that two tensor likes have the same shape,
|
||||
dtype, and device.
|
||||
dtype and device.
|
||||
|
||||
In the future this will validate additional metadata, like
|
||||
strides.
|
||||
@ -182,6 +182,16 @@ def compare_tensor_meta(a: TensorLikeType, b: TensorLikeType):
|
||||
raise AssertionError(msg)
|
||||
|
||||
|
||||
def compare_significant_strides(a: TensorLikeType, b: TensorLikeType):
|
||||
assert a.ndim == b.ndim
|
||||
|
||||
for idx in range(a.ndim):
|
||||
assert a.shape[idx] == b.shape[idx]
|
||||
if a.shape[idx] == 0 or a.shape[idx] == 1:
|
||||
continue
|
||||
assert a.stride()[idx] == b.stride()[idx]
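Dimensions of length 0 or 1 are skipped above because their strides never affect indexing; for example, these two tensors describe the same layout even though their raw strides differ (illustrative):

import torch

a = torch.empty(3, 1, 4)                                  # strides (4, 4, 1)
b = torch.empty(12).as_strided((3, 1, 4), (4, 999, 1))    # odd stride on the length-1 dim
assert a.shape == b.shape
assert a.stride() != b.stride()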
|
||||
|
||||
|
||||
#
|
||||
# Common helper functions
|
||||
#
|
||||
@ -207,25 +217,31 @@ def validate_shape(shape: Sequence):
|
||||
validate_dim_length(l)
|
||||
|
||||
|
||||
def validate_idx(shape: Sequence, idx: int):
|
||||
def validate_idx(rank: int, idx: int):
|
||||
"""
|
||||
Validates that idx is a valid idx for the given shape.
|
||||
0 and -1 is a valid index for an empty shape
|
||||
Validates that idx is a valid index for the given shape.
|
||||
Assumes the index is already canonicalized.
|
||||
"""
|
||||
|
||||
assert isinstance(idx, int)
|
||||
ndim = len(shape) if len(shape) else 1
|
||||
assert idx >= 0 and idx < ndim
|
||||
assert isinstance(rank, int)
|
||||
|
||||
assert idx >= 0 and idx < rank or idx == 0
|
||||
|
||||
def validate_dimension_indices(rank: int, indices: DimsSequenceType):
|
||||
for idx in indices:
|
||||
validate_idx(rank, idx)
|
||||
|
||||
|
||||
def validate_exclusive_idx(shape: Sequence, ex_idx: int):
|
||||
def validate_exclusive_idx(rank: int, ex_idx: int):
|
||||
"""
|
||||
Validates that ex_idx is a valid exclusive index
|
||||
for the given shape.
|
||||
"""
|
||||
|
||||
assert isinstance(ex_idx, int)
|
||||
assert ex_idx > 0 and ex_idx <= len(shape)
|
||||
assert isinstance(rank, int)
|
||||
assert ex_idx > 0 and ex_idx <= rank
|
||||
|
||||
|
||||
# "Wraps" a dim (up to one time) for the given rank, allowing
|
||||
@ -368,18 +384,22 @@ _complex_dtypes = (torch.complex32, torch.complex64, torch.complex128)
|
||||
|
||||
|
||||
def is_boolean_dtype(dtype: torch.dtype) -> bool:
|
||||
assert isinstance(dtype, torch.dtype)
|
||||
return dtype is torch.bool
|
||||
|
||||
|
||||
def is_integer_dtype(dtype: torch.dtype) -> bool:
|
||||
assert isinstance(dtype, torch.dtype)
|
||||
return dtype in _integer_dtypes
|
||||
|
||||
|
||||
def is_float_dtype(dtype: torch.dtype) -> bool:
|
||||
assert isinstance(dtype, torch.dtype)
|
||||
return dtype in _float_dtypes
|
||||
|
||||
|
||||
def is_complex_dtype(dtype: torch.dtype) -> bool:
|
||||
assert isinstance(dtype, torch.dtype)
|
||||
return dtype in _complex_dtypes
|
||||
|
||||
|
||||
@ -530,6 +550,7 @@ def get_higher_dtype(
|
||||
raise RuntimeError("Unexpected termination!")
|
||||
|
||||
|
||||
# TODO: maybe unify with can_cast_to?
|
||||
def is_weakly_lesser_type(a: type, b: type) -> bool:
|
||||
"""
|
||||
Compares two types, a and b, returning True if a is weakly "less" than b.
|
||||
@ -883,7 +904,7 @@ def compute_reduction_output_shape(
|
||||
shape: ShapeType, dimensions: Sequence
|
||||
) -> Tuple[int, ...]:
|
||||
for idx in dimensions:
|
||||
validate_idx(shape, idx)
|
||||
validate_idx(len(shape), idx)
|
||||
|
||||
new_shape = []
|
||||
for idx in range(len(shape)):
|
||||
|
@ -4,6 +4,7 @@ import torch._prims as prims
|
||||
import torch._prims.utils as utils
|
||||
from torch._prims.utils import (
|
||||
DimsType,
|
||||
ShapeType,
|
||||
TensorLike,
|
||||
TensorLikeType,
|
||||
DimsSequenceType,
|
||||
@ -91,12 +92,12 @@ __all__ = [
|
||||
# 'hypot',
|
||||
"igamma",
|
||||
"igammac",
|
||||
# 'isclose', # abs, sub, le, add, mul
|
||||
"isclose",
|
||||
# 'lcm',
|
||||
# 'ldexp',
|
||||
"le",
|
||||
# 'logical_and',
|
||||
# 'logical_or',
|
||||
"logical_and",
|
||||
"logical_or",
|
||||
# 'logical_xor',
|
||||
"lt",
|
||||
# 'max', # implement with reductions
|
||||
@ -122,6 +123,7 @@ __all__ = [
|
||||
#
|
||||
# Data conversion and movement references
|
||||
#
|
||||
"clone",
|
||||
"copy_to", # TODO: add opinfo
|
||||
#
|
||||
# Reduction ops
|
||||
@ -133,10 +135,17 @@ __all__ = [
|
||||
# View & Shape Ops
|
||||
#
|
||||
"cat",
|
||||
"chunk",
|
||||
"flatten",
|
||||
"flip",
|
||||
"narrow",
|
||||
"permute",
|
||||
"transpose",
|
||||
"stack",
|
||||
"swap_axes", # alias for transpose
|
||||
"squeeze",
|
||||
"tensor_split",
|
||||
"transpose",
|
||||
"unsqueeze",
|
||||
]
|
||||
|
||||
Tensor = torch.Tensor
|
||||
@ -281,7 +290,8 @@ erf = _make_elementwise_unary_reference(
|
||||
)
|
||||
|
||||
erfinv = _make_elementwise_unary_reference(
|
||||
prims.erf_inv, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
|
||||
prims.erf_inv,
|
||||
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
|
||||
aten_op=torch.ops.aten.erfinv, # prim/aten name mismatch
|
||||
)
|
||||
|
||||
@ -302,7 +312,8 @@ floor = _make_elementwise_unary_reference(
|
||||
)
|
||||
|
||||
isfinite = _make_elementwise_unary_reference(
|
||||
prims.is_finite, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
|
||||
prims.is_finite,
|
||||
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
|
||||
aten_op=None, # CompositeImplicitAutograd
|
||||
)
|
||||
|
||||
@ -312,7 +323,8 @@ def _isnan(a: Tensor) -> Tensor:
|
||||
|
||||
|
||||
isnan = _make_elementwise_unary_reference(
|
||||
_isnan, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
|
||||
_isnan,
|
||||
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
|
||||
aten_op=torch.ops.aten.isnan, # prim/aten name mismatch
|
||||
)
|
||||
|
||||
@ -338,7 +350,8 @@ reciprocal = _make_elementwise_unary_reference(
|
||||
|
||||
# TODO: round takes additional kwargs
|
||||
round = _make_elementwise_unary_reference(
|
||||
prims.round, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
|
||||
prims.round,
|
||||
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
|
||||
aten_op=None, # TODO: this does need a decomp, but kwarg handling is needed
|
||||
)
|
||||
|
||||
@ -359,7 +372,8 @@ sqrt = _make_elementwise_unary_reference(
|
||||
)
|
||||
|
||||
square = _make_elementwise_unary_reference(
|
||||
prims.square, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
|
||||
prims.square,
|
||||
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
|
||||
aten_op=None, # CompositeImplicitAutograd
|
||||
)
|
||||
|
||||
@ -369,9 +383,12 @@ tan = _make_elementwise_unary_reference(
|
||||
|
||||
|
||||
def _make_elementwise_binary_reference(
|
||||
prim: Callable, *, type_promotion_kind, aten_op=infer_aten_op
|
||||
prim: Callable,
|
||||
*,
|
||||
type_promotion_kind,
|
||||
aten_op=infer_aten_op,
|
||||
has_out=True,
|
||||
) -> Callable:
|
||||
@out_wrapper
|
||||
@elementwise_type_promotion_wrapper(
|
||||
type_promoting_args=("a", "b"), type_promotion_kind=type_promotion_kind
|
||||
)
|
||||
@ -382,6 +399,9 @@ def _make_elementwise_binary_reference(
|
||||
a, b = _maybe_broadcast(a, b)
|
||||
return prim(a, b)
|
||||
|
||||
if has_out:
|
||||
_ref = out_wrapper(_ref)
|
||||
|
||||
if aten_op is infer_aten_op:
|
||||
aten_op = getattr(torch.ops.aten, prim.__name__)
|
||||
if aten_op is not None:
|
||||
@ -429,12 +449,14 @@ atan2 = _make_elementwise_binary_reference(
|
||||
|
||||
# TODO: add docstring
|
||||
bitwise_and = _make_elementwise_binary_reference(
|
||||
prims.bitwise_and, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
|
||||
prims.bitwise_and,
|
||||
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
|
||||
)
|
||||
|
||||
# TODO: add docstring
|
||||
bitwise_left_shift = _make_elementwise_binary_reference(
|
||||
prims.shift_left, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
|
||||
prims.shift_left,
|
||||
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
|
||||
aten_op=torch.ops.aten.bitwise_left_shift, # prim/aten name mismatch
|
||||
)
|
||||
|
||||
@ -504,11 +526,93 @@ igammac = _make_elementwise_binary_reference(
|
||||
prims.igammac, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
|
||||
)
|
||||
|
||||
|
||||
def isclose(
|
||||
a: TensorLikeType,
|
||||
b: TensorLikeType,
|
||||
rtol: float = 1e-05,
|
||||
atol: float = 1e-08,
|
||||
equal_nan: bool = False,
|
||||
) -> TensorLikeType:
|
||||
if a.dtype != b.dtype:
|
||||
msg = "Attempting to compare tensors of different dtypes {0} and {1}!".format(
|
||||
a.dtype, b.dtype
|
||||
)
|
||||
raise ValueError(msg)
|
||||
if rtol < 0:
|
||||
msg = "rtol must be greater than or equal to zero, but got {0}!".format(rtol)
|
||||
if atol < 0:
|
||||
msg = "atol must be greater than or equal to zero, but got {0}!".format(atol)
|
||||
|
||||
close = eq(a, b)
|
||||
if equal_nan and (utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype)):
|
||||
close = logical_or(close, logical_and(isnan(a), isnan(b)))
|
||||
|
||||
# Note: In case of zero tolerances the closeness inequality degenerates to an equality check.
|
||||
# In this case, the short-circuit prevents false positives as detailed in the paragraph below.
|
||||
if atol == 0 and rtol == 0:
|
||||
return close
|
||||
|
||||
# Note [closeness error computation]
|
||||
# atol and rtol are provided as doubles, so the computation
|
||||
# rtol * other will produce a float or complex tensor.
|
||||
# When the difference (self - other) is compared to it then the
|
||||
# tensor representing the difference will also be cast to float or complex.
|
||||
# However, since (self - other) in uint8 is very likely to produce a
|
||||
# negative value, this moves the cast forward so the difference is
|
||||
# always computed in a float or complex type.
|
||||
# If the values of the integer tensors cannot be exactly represented
|
||||
# by the default scalar type then this may cause an incorrect result.
|
||||
if not utils.is_float_dtype(a.dtype) and not utils.is_complex_dtype(a.dtype):
|
||||
a = prims.convert_element_type(a, torch.get_default_dtype())
|
||||
b = prims.convert_element_type(b, torch.get_default_dtype())
|
||||
|
||||
allowed_error = add(atol, abs(mul(b, rtol)))
|
||||
actual_error = abs(sub(a, b))
|
||||
|
||||
# Computes finite closeness
|
||||
result = logical_or(
|
||||
close, logical_and(isfinite(actual_error), le(actual_error, allowed_error))
|
||||
)
|
||||
|
||||
return result
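A worked example of the closeness rule composed above, |a - b| <= atol + rtol * |b| (illustrative):

import torch
import torch._refs as refs

a = torch.tensor([1.0, 100.0])
b = torch.tensor([1.05, 100.05])
# rtol * |b| covers an absolute error of 0.05 at 100.0 but not at 1.0
print(refs.isclose(a, b, rtol=1e-3, atol=0.0))   # tensor([False,  True])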
|
||||
|
||||
|
||||
# TODO: add docstring
|
||||
le = _make_elementwise_binary_reference(
|
||||
prims.le, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
|
||||
)
|
||||
|
||||
|
||||
def _logical_and(a: TensorLikeType, b: TensorLikeType):
|
||||
if not utils.is_boolean_dtype(a.dtype):
|
||||
a = ne(a, 0)
|
||||
if not utils.is_boolean_dtype(b.dtype):
|
||||
b = ne(b, 0)
|
||||
return bitwise_and(a, b)
|
||||
|
||||
|
||||
logical_and = _make_elementwise_binary_reference(
|
||||
_logical_and,
|
||||
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
|
||||
aten_op=torch.ops.aten.logical_and,
|
||||
)
|
||||
|
||||
|
||||
def _logical_or(a: TensorLikeType, b: TensorLikeType):
|
||||
if not utils.is_boolean_dtype(a.dtype):
|
||||
a = ne(a, 0)
|
||||
if not utils.is_boolean_dtype(b.dtype):
|
||||
b = ne(b, 0)
|
||||
return bitwise_or(a, b)
|
||||
|
||||
|
||||
logical_or = _make_elementwise_binary_reference(
|
||||
_logical_or,
|
||||
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
|
||||
aten_op=torch.ops.aten.logical_or,
|
||||
)
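The two references above first coerce non-boolean inputs with ne(x, 0), so they match torch's semantics on integer inputs (illustrative check):

import torch
import torch._refs as refs

x = torch.tensor([0, 2, -3])
y = torch.tensor([1, 0, 5])
assert torch.equal(refs.logical_and(x, y), torch.logical_and(x, y))
assert torch.equal(refs.logical_or(x, y), torch.logical_or(x, y))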
|
||||
|
||||
# TODO: add docstring
|
||||
lt = _make_elementwise_binary_reference(
|
||||
prims.lt, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
|
||||
@ -620,6 +724,10 @@ def where(
|
||||
#
|
||||
# Data Movement References
|
||||
#
|
||||
def clone(
|
||||
a: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format
|
||||
) -> TensorLikeType:
|
||||
return prims.clone(a, memory_format=memory_format)
|
||||
|
||||
|
||||
def copy_to(a: Tensor, b: Tensor, *, allow_cross_device=True):
|
||||
@ -781,11 +889,83 @@ def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
|
||||
return prims.concatenate(tensors, _dim)
|
||||
|
||||
|
||||
def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> Tuple[TensorLikeType, ...]:
|
||||
if chunks <= 0:
|
||||
msg = "Expected at least one chunk, but got {0}!".format(chunks)
|
||||
raise ValueError(msg)
|
||||
|
||||
dim = utils.canonicalize_dim(a.ndim, dim)
|
||||
length = a.shape[dim]
|
||||
chunk_size = math.ceil(length / chunks)
|
||||
full_chunks = math.floor(length / chunk_size)
|
||||
tail_chunk_size = length % chunk_size
|
||||
|
||||
result = []
|
||||
for i in range(full_chunks):
|
||||
result.append(narrow(a, dim, i * chunk_size, chunk_size))
|
||||
|
||||
if tail_chunk_size != 0:
|
||||
result.append(narrow(a, dim, full_chunks * chunk_size, tail_chunk_size))
|
||||
|
||||
return tuple(result)
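A worked example of the chunk-size arithmetic above (illustrative):

import math
import torch
import torch._refs as refs

length, chunks = 11, 4
chunk_size = math.ceil(length / chunks)            # 3
full_chunks = math.floor(length / chunk_size)      # 3
tail_chunk_size = length % chunk_size              # 2
pieces = refs.chunk(torch.arange(length), chunks)
assert tuple(p.numel() for p in pieces) == (3, 3, 3, 2)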
|
||||
|
||||
|
||||
# Note: flatten, unlike prim.collapse and prim.collapse_view has an inclusive end_dim
|
||||
def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorLikeType:
|
||||
start_dim = utils.canonicalize_dim(a.ndim, start_dim)
|
||||
end_dim = utils.canonicalize_dim(a.ndim, end_dim) + 1
|
||||
|
||||
# Tries to take a view
|
||||
# TODO: we could look at directing collapse_view to skip its meta function here
|
||||
new_shape, new_strides = prims._collapse_view_helper(a, start_dim, end_dim)
|
||||
if new_shape is not None:
|
||||
return prims.collapse_view(a, start_dim, end_dim)
|
||||
|
||||
# Makes a copy if it can't make a view
|
||||
result = prims.collapse(a, start_dim, end_dim)
|
||||
return result
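As the note above says, flatten first tries collapse_view and only copies when no view exists; one way to observe the difference, illustrative only, using data_ptr to detect aliasing:

import torch
import torch._refs as refs

a = torch.randn(2, 3, 4)
assert refs.flatten(a, 0, 1).data_ptr() == a.data_ptr()      # view

t = a.transpose(0, 1)                                         # strides block the view
assert refs.flatten(t, 0, 1).data_ptr() != t.data_ptr()       # copy via prims.collapse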
|
||||
|
||||
|
||||
def flip(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
|
||||
dims = utils.canonicalize_dims(a.ndim, dims) # type: ignore[assignment]
|
||||
return prims.rev(a, dims)
|
||||
|
||||
|
||||
def narrow(a: TensorLikeType, dim: int, start: int, length: int) -> TensorLikeType:
|
||||
dim = utils.canonicalize_dim(a.ndim, dim)
|
||||
return prims.slice_in_dim(a, start, start + length, axis=dim)
|
||||
|
||||
|
||||
def permute(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
|
||||
_permutation = utils.canonicalize_dims(a.ndim, dims)
|
||||
return prims.transpose(a, _permutation)
|
||||
|
||||
|
||||
# update to cat then view instead of unsqueezing each tensor
|
||||
@out_wrapper
|
||||
def stack(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
|
||||
tensors = tuple(unsqueeze(a, dim) for a in tensors)
|
||||
return cat(tensors, dim)
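The stack reference above is simply unsqueeze-then-cat (the comment notes a planned cat-then-view rewrite); a quick equivalence check, illustrative only:

import torch
import torch._refs as refs

ts = [torch.randn(2, 3) for _ in range(4)]
assert torch.equal(refs.stack(ts, dim=1), torch.stack(ts, dim=1))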
|
||||
|
||||
|
||||
# Note: although squeeze is documented as having the out= kwarg it doesn't
|
||||
def squeeze(a: TensorLikeType, dim: Optional[int] = None) -> TensorLikeType:
|
||||
if dim is not None:
|
||||
dim = utils.canonicalize_dim(a.ndim, dim)
|
||||
# Short-circuits if the tensor has no dimensions
|
||||
if len(a.shape) == 0:
|
||||
assert dim == 0
|
||||
return prims.view_of(a)
|
||||
|
||||
# Note: squeeze does not modify tensors when the given dim is not a dimension of length 1
|
||||
if a.shape[dim] != 1:
|
||||
return prims.view_of(a)
|
||||
return prims.squeeze(a, (dim,))
|
||||
|
||||
dims = tuple(idx for idx in range(len(a.shape)) if a.shape[idx] == 1)
|
||||
return prims.squeeze(a, dims)
|
||||
|
||||
|
||||
# Note: does not work with TensorMetas because of data-dependent control-flow
|
||||
def tensor_split(
|
||||
a: TensorLikeType,
|
||||
@ -875,5 +1055,12 @@ def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType:
|
||||
return prims.transpose(a, _permutation)
|
||||
|
||||
|
||||
def unsqueeze(a: TensorLikeType, dim: int) -> TensorLikeType:
|
||||
# Note that unsqueeze canonicalizes with rank + 1 because it allows
|
||||
# a new innermost dimension to be specified
|
||||
dim = utils.canonicalize_dim(a.ndim + 1, dim)
|
||||
return prims.expand_dims(a, (dim,))
|
||||
|
||||
|
||||
# Aliases for transpose
|
||||
swap_axes = transpose
|
||||
|
torch/_refs/nn/__init__.py (new file, 0 lines)
torch/_refs/nn/functional/__init__.py (new file, 47 lines)
@@ -0,0 +1,47 @@
import torch

import torch._prims.utils as utils
from torch._prims.utils import (
    TensorLikeType,
    NumberType,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
)
import torch._refs as refs
from torch._prims.wrappers import elementwise_type_promotion_wrapper

from typing import Optional

__all__ = [
    "elu",
]

# elu is implemented specially because it has an alpha argument
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.OP_MATH,
)
def elu(
    a: TensorLikeType, alpha: Optional[NumberType] = None, inplace: bool = False
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.elu
    """

    if inplace:
        raise NotImplementedError

    rhs: TensorLikeType
    if alpha is not None:
        python_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(alpha), python_type):
            msg = (
                "alpha argument of type {0} cannot be safely cast to type {1}!".format(
                    type(alpha), python_type
                )
            )
            raise ValueError(msg)
        rhs = refs.mul(alpha, refs.expm1(a))
    else:
        rhs = refs.expm1(a)

    return refs.where(refs.gt(a, 0), a, rhs)
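A small sanity check of the reference above against torch.nn.functional.elu (illustrative; low-precision dtypes need the tolerance overrides tied to https://github.com/pytorch/pytorch/issues/77054):

import torch
import torch.nn.functional as F
import torch._refs.nn.functional as refs_nn

x = torch.randn(100)
assert torch.allclose(refs_nn.elu(x), F.elu(x))
assert torch.allclose(refs_nn.elu(x, alpha=0.5), F.elu(x, alpha=0.5))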
torch/_refs/special/__init__.py (new file, 23 lines)
@@ -0,0 +1,23 @@
import torch

import torch._prims as prims
import torch._prims.utils as utils
from torch._prims.utils import TensorLikeType
from torch._prims.wrappers import out_wrapper, elementwise_type_promotion_wrapper
from torch._refs import _make_elementwise_unary_reference

__all__ = [
    "i0e",
    "i1e",
]

i0e = _make_elementwise_unary_reference(
    prims.bessel_i0e,
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    aten_op=torch.ops.aten.special_i0e,
)
i1e = _make_elementwise_unary_reference(
    prims.bessel_i1e,
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    aten_op=torch.ops.aten.special_i1e,
)
@ -44,6 +44,8 @@ from torch.testing._internal.common_utils import \
|
||||
import torch.testing._internal.opinfo_helper as opinfo_helper
|
||||
|
||||
import torch._refs as refs # noqa: F401
|
||||
import torch._refs.nn.functional
|
||||
import torch._refs.special
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
@ -3735,13 +3737,20 @@ def sample_inputs_comparison_ops(op, device, dtype, requires_grad, **kwargs):
|
||||
yield SampleInput(lhs, args=(lhs.clone(),))
|
||||
|
||||
def sample_inputs_stack(op_info, device, dtype, requires_grad, **kwargs):
|
||||
tensors = [
|
||||
make_tensor((S, S), dtype=dtype, device=device, requires_grad=requires_grad),
|
||||
make_tensor((S, S), dtype=dtype, device=device, requires_grad=requires_grad),
|
||||
make_tensor((S, S), dtype=dtype, device=device, requires_grad=requires_grad),
|
||||
]
|
||||
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
|
||||
return (SampleInput(tensors, args=(0,)),)
|
||||
# shape x number of tensors
|
||||
cases = (
|
||||
((3, 4), 1),
|
||||
((1, 2, 1, 4), 3),
|
||||
((0, 1, 0), 2),)
|
||||
|
||||
for shape, num_tensors in cases:
|
||||
tensors = []
|
||||
for _ in range(num_tensors):
|
||||
tensors.append(make_arg(shape))
|
||||
for dim in range(-1, len(shape) - 1):
|
||||
yield SampleInput(tensors, args=(dim,))
|
||||
|
||||
def sample_inputs_cat_concat(op_info, device, dtype, requires_grad, **kwargs):
|
||||
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
@ -5559,6 +5568,7 @@ def sample_inputs_narrow(op_info, device, dtype, requires_grad, **kwargs):
|
||||
((S, S, S), (-1, 2, 2)),
|
||||
((S, S, S), (1, 0, 0)),
|
||||
((S, S, S), (-1, 0, 0)),
|
||||
((S, S, S), (2, 1, 2)),
|
||||
)
|
||||
|
||||
for shape, args in shapes_and_args:
|
||||
@ -5628,7 +5638,10 @@ def sample_unsqueeze(op_info, device, dtype, requires_grad, **kwargs):
|
||||
((3, 4, 5), 3),
|
||||
((3, 4, 5), -1),
|
||||
((3, 4, 5), -3),
|
||||
((), 0)
|
||||
((), 0),
|
||||
((), -1),
|
||||
((1,), 0),
|
||||
((1,), -1),
|
||||
]
|
||||
|
||||
samples = []
|
||||
@ -7788,7 +7801,7 @@ def sample_inputs_view_reshape(op_info, device, dtype, requires_grad, **kwargs):
|
||||
((S * S * 2, S), (S, -1)),
|
||||
((S,), (S,)),
|
||||
((), ()),
|
||||
((), (1,)))
|
||||
((), (1,)),)
|
||||
|
||||
for case in cases:
|
||||
shape, args = case
|
||||
@ -7800,6 +7813,38 @@ def sample_inputs_view_reshape(op_info, device, dtype, requires_grad, **kwargs):
|
||||
inp.clone().transpose(0, 1).requires_grad_(requires_grad),
|
||||
args=(args, )))
|
||||
|
||||
def reference_inputs_reshape(op, device, dtype, requires_grad, **kwargs):
|
||||
yield from sample_inputs_view_reshape(op, device, dtype, requires_grad, **kwargs)
|
||||
|
||||
cases = (
|
||||
((125,), (25, 5)),
|
||||
((25, 25), (1, 5, 5, 1, 5, 1, 5, 1)),
|
||||
((16, 32), (2, 4, 1, 4, 4, 1, 4)),
|
||||
((16, 12), (12, 16)),
|
||||
((1, 16, 12), (12, 16)),
|
||||
((1, 5, 1, 5), (25, 1)),
|
||||
((2, 4, 2), (4, 4)),
|
||||
((1, 4), (1, 1, 2, 1, 2)),
|
||||
((3, 5, 7), (7, 5, 3)),
|
||||
((1,), ()),
|
||||
((5, 0, 2, 3), (5, 0, 2, 3)),
|
||||
((2, 1, 0, 3, 1), (5, 0)),
|
||||
((1, 0, 3), ())
|
||||
)
|
||||
|
||||
irreversible_cases = (
|
||||
((), (-1,)),
|
||||
)
|
||||
|
||||
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
|
||||
for a, b in cases:
|
||||
yield SampleInput(make_arg(a), args=(b,))
|
||||
yield SampleInput(make_arg(b), args=(a,))
|
||||
yield SampleInput(make_arg(a, noncontiguous=True).transpose(0, -1), args=(b,))
|
||||
|
||||
for a, b in irreversible_cases:
|
||||
yield SampleInput(make_arg(a), args=(b,))
|
||||
|
||||
def sample_inputs_view_as_reshape_as(op_info, device, dtype, requires_grad, **kwargs):
|
||||
make_arg = partial(make_tensor, dtype=dtype, device=device)
|
||||
|
||||
@ -8009,7 +8054,7 @@ def sample_inputs_nonzero(op_info, device, dtype, requires_grad, **kwargs):
|
||||
kwargs=dict(as_tuple=as_tuple)))
|
||||
|
||||
def sample_inputs_chunk(op_info, device, dtype, requires_grad, **kwargs):
|
||||
make_arg = partial(make_tensor, dtype=dtype, device=device)
|
||||
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
|
||||
|
||||
cases = (((S, S, S), (2,)),
|
||||
((S, S, S), (S, 1)),
|
||||
@ -8017,7 +8062,32 @@ def sample_inputs_chunk(op_info, device, dtype, requires_grad, **kwargs):
|
||||
|
||||
for case in cases:
|
||||
shape, args = case
|
||||
yield(SampleInput(make_arg(shape, requires_grad=requires_grad), args=args))
|
||||
yield(SampleInput(make_arg(shape), args=args))
|
||||
|
||||
def reference_inputs_chunk(op, device, dtype, requires_grad, **kwargs):
|
||||
yield from sample_inputs_chunk(op, device, dtype, requires_grad, **kwargs)
|
||||
|
||||
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
|
||||
|
||||
# shape x chunks x dim
|
||||
cases = (
|
||||
((13, 9, 11), 17, -1),
|
||||
((13, 9, 11), 11, -1),
|
||||
((13,), 12, -1),
|
||||
((15,), 12, -1),
|
||||
((15,), 7, 0),
|
||||
((15,), 9, 0),
|
||||
((3, 7), 9, 1),
|
||||
((3, 7), 9, 0),
|
||||
((3, 7), 2, 0),
|
||||
((3, 7), 3, 0),
|
||||
((3, 7), 1, 0),
|
||||
((3, 7), 1, 1),
|
||||
((4, 4), 2, 0),
|
||||
)
|
||||
|
||||
for shape, chunks, dim in cases:
|
||||
yield SampleInput(make_arg(shape), args=(chunks, dim))
|
||||
|
||||
def sample_inputs_kthvalue(op_info, device, dtype, requires_grad, **kwargs):
|
||||
def _tensor(shape, dtype=dtype, low=None, high=None):
|
||||
@ -10041,6 +10111,7 @@ op_db: List[OpInfo] = [
|
||||
OpInfo('chunk',
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
|
||||
sample_inputs_func=sample_inputs_chunk,
|
||||
reference_inputs_func=reference_inputs_chunk,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
supports_out=False),
|
||||
@ -10091,7 +10162,19 @@ op_db: List[OpInfo] = [
|
||||
assert_autodiffed=True,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
sample_inputs_func=sample_inputs_clamp),
|
||||
sample_inputs_func=sample_inputs_clamp,
|
||||
skips=(
|
||||
# boolean alpha not handled properly
|
||||
DecorateInfo(unittest.expectedFailure,
|
||||
'TestCudaFuserOpInfo',
|
||||
'test_nvfuser_correctness',
|
||||
dtypes=(torch.bool, torch.int32, torch.int64)),
|
||||
# boolean alpha not handled properly
|
||||
DecorateInfo(unittest.expectedFailure,
|
||||
'TestNNCOpInfo',
|
||||
'test_nnc_correctness',
|
||||
dtypes=(torch.bool, torch.int32, torch.int64)),
|
||||
)),
|
||||
UnaryUfuncInfo('clamp',
|
||||
variant_test_name='scalar',
|
||||
aliases=('clip', ),
|
||||
@ -15042,6 +15125,7 @@ op_db: List[OpInfo] = [
|
||||
OpInfo('reshape',
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
|
||||
sample_inputs_func=sample_inputs_view_reshape,
|
||||
reference_inputs_func=reference_inputs_reshape,
|
||||
supports_out=False,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
@ -15687,6 +15771,11 @@ op_db: List[OpInfo] = [
|
||||
assert_autodiffed=True,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
skips=(
|
||||
# https://github.com/pytorch/pytorch/issues/77046
|
||||
DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
|
||||
),
|
||||
),
|
||||
OpInfo('hstack',
|
||||
dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
|
||||
@ -18079,6 +18168,39 @@ python_ref_db = [
|
||||
"_refs.tan",
|
||||
torch_opinfo_name="tan",
|
||||
),
|
||||
#
|
||||
# Elementwise Unary Special OpInfos
|
||||
#
|
||||
ElementwiseUnaryPythonRefInfo(
|
||||
"_refs.special.i0e",
|
||||
torch_opinfo_name="special.i0e",
|
||||
decorators=(
|
||||
DecorateInfo(toleranceOverride({
|
||||
torch.bfloat16: tol(atol=1e-2, rtol=0),
|
||||
}), 'TestCommon', 'test_python_reference_consistency', device_type='cpu'),
|
||||
),
|
||||
),
|
||||
ElementwiseUnaryPythonRefInfo(
|
||||
"_refs.special.i1e",
|
||||
torch_opinfo_name="special.i1e",
|
||||
),
|
||||
#
|
||||
# Elementwise Unary nn.functional OpInfos
|
||||
#
|
||||
ElementwiseUnaryPythonRefInfo(
|
||||
"_refs.nn.functional.elu",
|
||||
torch_opinfo_name="nn.functional.elu",
|
||||
decorators=(
|
||||
# https://github.com/pytorch/pytorch/issues/77054
|
||||
DecorateInfo(toleranceOverride({
|
||||
torch.bfloat16: tol(atol=1e-2, rtol=0),
|
||||
torch.float16: tol(atol=1e-3, rtol=0),
|
||||
}), 'TestCommon', 'test_python_reference_consistency', device_type='cpu'),
|
||||
),
|
||||
),
|
||||
#
|
||||
# Elementwise Binary OpInfos
|
||||
#
|
||||
ElementwiseBinaryPythonRefInfo(
|
||||
"_refs.add",
|
||||
torch_opinfo_name="add",
|
||||
@ -18152,10 +18274,26 @@ python_ref_db = [
|
||||
"_refs.igammac",
|
||||
torch_opinfo_name="igammac",
|
||||
),
|
||||
ElementwiseBinaryPythonRefInfo(
|
||||
"_refs.isclose",
|
||||
torch_opinfo_name="isclose",
|
||||
skips=(
|
||||
# Intentional xfail -- isclose does not type promote
|
||||
DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
|
||||
),
|
||||
),
|
||||
ElementwiseBinaryPythonRefInfo(
|
||||
"_refs.le",
|
||||
torch_opinfo_name="le",
|
||||
),
|
||||
ElementwiseBinaryPythonRefInfo(
|
||||
"_refs.logical_and",
|
||||
torch_opinfo_name="logical_and",
|
||||
),
|
||||
ElementwiseBinaryPythonRefInfo(
|
||||
"_refs.logical_or",
|
||||
torch_opinfo_name="logical_or",
|
||||
),
|
||||
ElementwiseBinaryPythonRefInfo(
|
||||
"_refs.lt",
|
||||
torch_opinfo_name="lt",
|
||||
@ -18241,6 +18379,13 @@ python_ref_db = [
|
||||
),
|
||||
),
|
||||
#
|
||||
# Data Conversion & Data Movement Opinfos
|
||||
#
|
||||
PythonRefInfo(
|
||||
"_refs.clone",
|
||||
torch_opinfo_name="clone",
|
||||
),
|
||||
#
|
||||
# View & Shape OpInfos
|
||||
#
|
||||
PythonRefInfo(
|
||||
@ -18255,10 +18400,41 @@ python_ref_db = [
|
||||
dtypes=(torch.chalf,)),
|
||||
)
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.chunk",
|
||||
torch_opinfo_name="chunk",
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.flatten",
|
||||
torch_opinfo_name="flatten",
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.flip",
|
||||
torch_opinfo_name="flip",
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.narrow",
|
||||
torch_opinfo_name="narrow",
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.permute",
|
||||
torch_opinfo_name="permute",
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.stack",
|
||||
torch_opinfo_name="stack",
|
||||
skips=(
|
||||
# ValueError: Callable cat has no meta function!
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_reference_meta_functions'),
|
||||
# https://github.com/pytorch/pytorch/issues/77046
|
||||
DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
|
||||
),
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.squeeze",
|
||||
torch_opinfo_name="squeeze",
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.tensor_split",
|
||||
torch_opinfo_name="tensor_split",
|
||||
@ -18271,6 +18447,10 @@ python_ref_db = [
|
||||
"_refs.transpose",
|
||||
torch_opinfo_name="transpose",
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.unsqueeze",
|
||||
torch_opinfo_name="unsqueeze",
|
||||
),
|
||||
#
|
||||
# Reduction OpInfos
|
||||
#
|
||||