mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144256 Approved by: https://github.com/aorenste
586 lines
19 KiB
Python
586 lines
19 KiB
Python
# mypy: allow-untyped-defs
|
|
import torch
|
|
import functools
|
|
from torch.testing import make_tensor
|
|
from torch.testing._internal.opinfo.core import (
|
|
OpInfo,
|
|
SampleInput,
|
|
)
|
|
from torch.testing._internal.common_dtype import all_types_and
|
|
import numpy as np
|
|
from torch.testing._internal.autograd_function_db import (
|
|
sample_inputs_numpy_cube,
|
|
sample_inputs_numpy_mul,
|
|
sample_inputs_numpy_mul_scalar,
|
|
sample_inputs_numpy_sort,
|
|
sample_inputs_numpy_take,
|
|
)
|
|
from torch import Tensor
|
|
from torch.types import Number
|
|
from typing import * # noqa: F403
|
|
|
|
# Note: [custom op db]
|
|
#
|
|
# This is a collection of custom operator test cases written as OpInfos
|
|
# so they can easily be consumed by OpInfo-based tests to check if subsystems
|
|
# support them correctly.
|
|
|
|
def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array, going through a CPU copy first."""
    host_tensor = tensor.cpu()
    return host_tensor.numpy()
|
|
|
|
@torch.library.custom_op("_torch_testing::numpy_cube", mutates_args=())
def numpy_cube(x: Tensor) -> tuple[Tensor, Tensor]:
    """Cube ``x`` with NumPy.

    Returns ``(x ** 3, 3 * x ** 2)``; both tensors are created on ``x``'s
    device.
    """
    arr = to_numpy(x)
    target = x.device
    cubed = torch.tensor(arr ** 3, device=target)
    derivative = torch.tensor(3 * arr ** 2, device=target)
    return cubed, derivative
|
|
|
|
@numpy_cube.register_fake
def _(x):
    """Fake impl for ``numpy_cube``: each output is a clone of ``x``."""
    cube_like = x.clone()
    dx_like = x.clone()
    return cube_like, dx_like
|
|
|
|
def numpy_cube_setup_context(ctx, inputs, output):
    """Save the input and the op's second output (``dx``) for backward."""
    (x,) = inputs
    _, dx = output
    ctx.save_for_backward(x, dx)
|
|
|
|
def numpy_cube_backward(ctx, grad_out, grad_dx):
    """Backward for ``numpy_cube``.

    ``grad_out`` is scaled by the saved ``dx`` and ``grad_dx`` by ``6 * x``;
    the two contributions are summed into the gradient for ``x``.
    """
    x, dx = ctx.saved_tensors
    through_cube = numpy_mul(grad_out, dx)
    through_dx = 6 * numpy_mul(grad_dx, x)
    return through_cube + through_dx


numpy_cube.register_autograd(numpy_cube_backward, setup_context=numpy_cube_setup_context)
|
|
|
|
def numpy_cube_vmap(info, in_dims, x):
    """vmap rule for ``numpy_cube``: run the op on ``x`` unchanged.

    Both outputs are reported with the same batch dim as the input.
    """
    x_bdim = in_dims[0]
    outputs = numpy_cube(x)
    return outputs, (x_bdim, x_bdim)


numpy_cube.register_vmap(numpy_cube_vmap)
|
|
|
|
@torch.library.custom_op("_torch_testing::numpy_mul", mutates_args=())
def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
    """Multiply ``x`` and ``y`` with NumPy; the result is placed on ``x``'s device."""
    product = to_numpy(x) * to_numpy(y)
    return torch.tensor(product, device=x.device)
|
|
|
|
@numpy_mul.register_fake
def _(x, y):
    """Fake impl for ``numpy_mul``: contiguous product of same-device inputs."""
    assert x.device == y.device
    product = x * y
    return product.contiguous()
|
|
|
|
def numpy_mul_setup_context(ctx, inputs, output):
    """Save every input of ``numpy_mul`` for the backward pass."""
    operands = tuple(inputs)
    ctx.save_for_backward(*operands)
|
|
|
|
def numpy_mul_backward(ctx, grad_out):
    """Backward for ``numpy_mul``.

    Each gradient is ``grad_out`` times the *other* operand, and is only
    computed when ``ctx.needs_input_grad`` asks for it (otherwise ``None``).
    """
    x, y = ctx.saved_tensors
    grad_x = None
    grad_y = None
    if ctx.needs_input_grad[0]:
        grad_x = grad_out * y
    if ctx.needs_input_grad[1]:
        grad_y = grad_out * x
    return grad_x, grad_y


numpy_mul.register_autograd(numpy_mul_backward, setup_context=numpy_mul_setup_context)
|
|
|
|
def numpy_mul_vmap(info, in_dims, x, y):
    """vmap rule for ``numpy_mul``.

    Puts each operand's batch dim last (adding a size-1 trailing dim to an
    unbatched operand), multiplies, and moves the batch dim to the front.
    """
    x_bdim, y_bdim = in_dims

    def batch_last(t, bdim):
        # Unbatched operands get a trailing singleton dim so they broadcast.
        if bdim is None:
            return t.unsqueeze(-1)
        return t.movedim(bdim, -1)

    product = batch_last(x, x_bdim) * batch_last(y, y_bdim)
    return product.movedim(-1, 0), 0


numpy_mul.register_vmap(numpy_mul_vmap)
|
|
|
|
@torch.library.custom_op("_torch_testing::numpy_mul_scalar", mutates_args=())
def numpy_mul_scalar(x: Tensor, *, scalar: float) -> Tensor:
    """Multiply ``x`` by the keyword-only ``scalar`` using NumPy."""
    scaled = to_numpy(x) * scalar
    return torch.tensor(scaled, device=x.device)
|
|
|
|
@numpy_mul_scalar.register_fake
def _(x, *, scalar):
    """Fake impl for ``numpy_mul_scalar``: contiguous ``x * scalar``."""
    scaled = x * scalar
    return scaled.contiguous()
|
|
|
|
def numpy_mul_scalar_setup_context(ctx, inputs, keyword_only_inputs, output):
    """Remember the keyword-only ``scalar`` on ``ctx`` for the backward pass."""
    scalar = keyword_only_inputs["scalar"]
    ctx.scalar = scalar
|
|
|
|
def numpy_mul_scalar_backward(ctx, grad_out):
    """Backward for ``numpy_mul_scalar``: scale ``grad_out`` by the stored scalar."""
    return grad_out * ctx.scalar


numpy_mul_scalar.register_autograd(numpy_mul_scalar_backward, setup_context=numpy_mul_scalar_setup_context)
|
|
|
|
def numpy_mul_scalar_vmap(info, in_dims, x, *, scalar):
    """vmap rule for ``numpy_mul_scalar``.

    Moves the batch dim last (or adds a trailing singleton dim when ``x`` is
    unbatched), scales, and returns the result with the batch dim in front.
    """
    (x_bdim,) = in_dims
    if x_bdim is None:
        batched = x.unsqueeze(-1)
    else:
        batched = x.movedim(x_bdim, -1)
    scaled = batched * scalar
    return scaled.movedim(-1, 0), 0


numpy_mul_scalar.register_vmap(numpy_mul_scalar_vmap)
|
|
|
|
@torch.library.custom_op("_torch_testing::numpy_sort", mutates_args=())
def numpy_sort(x: Tensor, dim: int) -> tuple[Tensor, Tensor, Tensor]:
    """Sort ``x`` along ``dim`` with NumPy.

    Returns the sorted values, the argsort indices, and the argsort of those
    indices, all on ``x``'s device.
    """
    device = x.device
    values = to_numpy(x)
    order = np.argsort(values, axis=dim)
    inverse_order = np.argsort(order, axis=dim)
    sorted_values = np.take_along_axis(values, order, axis=dim)
    return tuple(
        torch.tensor(arr, device=device)
        for arr in (sorted_values, order, inverse_order)
    )
|
|
|
|
@numpy_sort.register_fake
def _(x, dim):
    """Fake impl for ``numpy_sort``: values like ``x`` plus two long index tensors."""
    values = torch.empty_like(x)
    indices = torch.empty_like(x, dtype=torch.long)
    inverse_indices = torch.empty_like(x, dtype=torch.long)
    return values, indices, inverse_indices
|
|
|
|
def numpy_sort_setup_context(ctx, inputs, output):
    """Record ``dim`` and save both index outputs, marking them non-differentiable."""
    _sorted, ind, ind_inv = output
    index_outputs = (ind, ind_inv)
    ctx.dim = inputs[1]
    ctx.save_for_backward(*index_outputs)
    ctx.mark_non_differentiable(*index_outputs)
|
|
|
|
def numpy_sort_backward(ctx, grad_out, grad_ind, grad_ind_inv):
    """Backward for ``numpy_sort``.

    Gathers ``grad_out`` with the inverse indices via ``numpy_take``; ``dim``
    receives no gradient.
    """
    ind, ind_inv = ctx.saved_tensors
    grad_x = numpy_take(grad_out, ind_inv, ind, ctx.dim)
    return grad_x, None


numpy_sort.register_autograd(numpy_sort_backward, setup_context=numpy_sort_setup_context)
|
|
|
|
def numpy_sort_vmap(info, in_dims, x, dim):
    """vmap rule for ``numpy_sort``.

    Moves the batch dim to the front, wraps a negative ``dim`` against the
    rank without the batch dim, and sorts along ``dim + 1``.
    """
    x_bdim, _ = in_dims
    batched = x.movedim(x_bdim, 0)
    if dim < 0:
        # One of ``batched``'s dims is the batch dim, hence the ``- 1``.
        dim = dim + batched.dim() - 1
    return numpy_sort(batched, dim + 1), (0, 0, 0)


numpy_sort.register_vmap(numpy_sort_vmap)
|
|
|
|
@torch.library.custom_op("_torch_testing::numpy_take", mutates_args=())
def numpy_take(x: Tensor, ind: Tensor, ind_inv: Tensor, dim: int) -> Tensor:
    """Gather from ``x`` along ``dim`` at ``ind`` using ``np.take_along_axis``.

    ``ind_inv`` is accepted but not read by the forward computation.
    """
    gathered = np.take_along_axis(to_numpy(x), to_numpy(ind), dim)
    return torch.tensor(gathered, device=x.device)
|
|
|
|
@numpy_take.register_fake
def _(x, ind, ind_inv, dim):
    """Fake impl for ``numpy_take``: validate the index tensors, return a tensor like ``x``."""
    assert x.device == ind.device and x.device == ind_inv.device
    assert ind.dtype == torch.long and ind_inv.dtype == torch.long
    return torch.empty_like(x)
|
|
|
|
def numpy_take_setup_context(ctx, inputs, output):
    """Save both index tensors and record ``dim`` for the backward pass."""
    _, ind, ind_inv, dim = inputs
    ctx.save_for_backward(ind, ind_inv)
    ctx.dim = dim
|
|
|
|
def numpy_take_backward(ctx, grad_out):
    """Backward for ``numpy_take``.

    Calls ``numpy_take`` on ``grad_out`` with the two index tensors swapped;
    only ``x`` receives a gradient.
    """
    ind, ind_inv = ctx.saved_tensors
    return numpy_take(grad_out, ind_inv, ind, ctx.dim), None, None, None


numpy_take.register_autograd(numpy_take_backward, setup_context=numpy_take_setup_context)
|
|
|
|
def numpy_take_vmap(info, in_dims, x, ind, ind_inv, dim):
    """vmap rule for ``numpy_take``.

    Every tensor argument is given a leading batch dim (moved there if it is
    batched, expanded to ``info.batch_size`` if it is not), and the op is run
    along ``dim + 1``. The output's batch dim is 0.
    """
    x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims

    # Wrap a negative ``dim`` against the logical (per-example) rank of ``x``.
    # A batched ``x`` carries one extra physical dim, so its logical rank is
    # ``x.dim() - 1`` no matter where the batch dim sits.
    logical_dim = x.dim() if x_bdim is None else x.dim() - 1
    dim = dim if dim >= 0 else dim + logical_dim

    def expand_bdim(t, bdim):
        # Unbatched tensors are broadcast across the batch.
        if bdim is None:
            return t.expand(info.batch_size, *t.shape)
        return t.movedim(bdim, 0)

    x = expand_bdim(x, x_bdim)
    ind = expand_bdim(ind, ind_bdim)
    ind_inv = expand_bdim(ind_inv, ind_inv_bdim)

    return numpy_take(x, ind, ind_inv, dim + 1), 0


numpy_take.register_vmap(numpy_take_vmap)
|
|
|
|
@torch.library.custom_op("_torch_testing::numpy_nonzero", mutates_args=())
def numpy_nonzero(x: Tensor) -> Tensor:
    """Return the indices of the nonzero entries of ``x``, one row per entry.

    Raises:
        RuntimeError: if fewer than two nonzero entries are found.
    """
    coords = np.nonzero(to_numpy(x))
    res = np.stack(coords, axis=1)
    if res.shape[0] <= 1:
        raise RuntimeError("not supported")
    return torch.tensor(res, device=x.device)
|
|
|
|
@numpy_nonzero.register_fake
def _(x):
    """Fake impl for ``numpy_nonzero``: ``(i0, x.dim())`` long tensor with unbacked ``i0``."""
    num_nonzero = torch._custom_op.impl.get_ctx().create_unbacked_symint()
    return x.new_empty([num_nonzero, x.dim()], dtype=torch.long)
|
|
|
|
def sample_inputs_numpy_nonzero(opinfo, device, dtype, requires_grad, **kwargs):
    """Yield one 10-element sample whose entries are randomly zeroed by a 0/1 mask."""
    numel = 10
    values = make_tensor(
        numel, device=device, dtype=dtype, requires_grad=requires_grad, low=0.9, high=2
    )
    keep = make_tensor(numel, low=0, high=2, device=device, dtype=torch.long)
    # In-place masking of a tensor that may require grad must not be recorded.
    with torch.no_grad():
        values *= keep

    yield SampleInput(values, args=())
|
|
|
|
def numpy_nonzero_vmap(info, in_dims, x):
    """vmap rule for ``numpy_nonzero``: always raises ``NotImplementedError``."""
    message = "Operator is data-dependent and cannot be vmapped."
    raise NotImplementedError(message)


numpy_nonzero.register_vmap(numpy_nonzero_vmap)
|
|
|
|
@torch.library.custom_op("_torch_testing::numpy_view_copy", mutates_args=())
def numpy_view_copy(x: Tensor, shape: Sequence[int]) -> Tensor:
    """Reshape ``x`` to ``shape`` with NumPy and return a copy on ``x``'s device."""
    reshaped = to_numpy(x).reshape(shape)
    return torch.tensor(np.copy(reshaped), device=x.device)
|
|
|
|
@numpy_view_copy.register_fake
def _(x, shape) -> Tensor:
    """Fake impl for ``numpy_view_copy``: clone, view as ``shape``, clone again."""
    reshaped = x.clone().view(shape)
    return reshaped.clone()
|
|
|
|
def numpy_view_copy_setup_context(ctx, inputs, output) -> None:
    """Record the first input's shape so backward can reshape the gradient to it."""
    source = inputs[0]
    ctx.x_shape = source.shape
|
|
|
|
def numpy_view_copy_backward(ctx, grad_out):
    """Backward for ``numpy_view_copy``: reshape ``grad_out`` to the recorded input shape."""
    grad_x = torch.ops._torch_testing.numpy_view_copy(grad_out, ctx.x_shape)
    return grad_x, None


numpy_view_copy.register_autograd(numpy_view_copy_backward, setup_context=numpy_view_copy_setup_context)
|
|
|
|
def numpy_view_copy_vmap(info, in_dims, x, shape):
    """vmap rule for ``numpy_view_copy``.

    Moves the batch dim to the front and prepends its size to the requested
    ``shape`` before calling the op; the output's batch dim is 0.
    """
    x_bdim, _ = in_dims
    batched = x.movedim(x_bdim, 0)
    batched_shape = (batched.shape[0], *shape)
    return numpy_view_copy(batched, batched_shape), 0


numpy_view_copy.register_vmap(numpy_view_copy_vmap)
|
|
|
|
def sample_inputs_numpy_view_copy(opinfo, device, dtype, requires_grad, **kwargs):
    """Yield one ``(2, 3, 4)`` sample to be viewed as ``[2, 12]``."""
    source = make_tensor(
        2, 3, 4, device=device, dtype=dtype, requires_grad=requires_grad, low=0.9, high=2
    )
    yield SampleInput(source, args=([2, 12],))
|
|
|
|
@torch.library.custom_op('_torch_testing::numpy_cat', mutates_args=())
def numpy_cat(xs: Sequence[Tensor], dim: int) -> Tensor:
    """Concatenate ``xs`` along ``dim`` with NumPy.

    ``xs`` must be non-empty and share one device and one dtype; the result
    is placed on the first tensor's device.
    """
    assert len(xs) > 0
    first = xs[0]
    assert all(x.device == first.device for x in xs)
    assert all(x.dtype == first.dtype for x in xs)
    joined = np.concatenate([to_numpy(x) for x in xs], axis=dim)
    return torch.tensor(joined, device=first.device)
|
|
|
|
@numpy_cat.register_fake
def _(xs, dim):
    """Fake impl for ``numpy_cat``: same input checks, then ``torch.cat``."""
    assert len(xs) > 0
    first = xs[0]
    assert all(x.device == first.device for x in xs)
    assert all(x.dtype == first.dtype for x in xs)
    return torch.cat(xs, dim=dim)
|
|
|
|
def numpy_cat_setup_context(ctx, inputs, output):
    """Record ``dim`` and each input's size along it for the backward split."""
    xs, dim = inputs
    ctx.dim = dim
    ctx.dim_sizes = [x.shape[dim] for x in xs]
|
|
|
|
def numpy_cat_backward(ctx, grad_out):
    """Backward for ``numpy_cat``: split ``grad_out`` back into per-input pieces."""
    # Cumulative sizes, minus the last one, are the split points along ``dim``.
    boundaries = np.cumsum(ctx.dim_sizes)[:-1]
    grad_xs = torch.ops._torch_testing.numpy_split_copy(grad_out, list(boundaries), ctx.dim)
    return grad_xs, None


numpy_cat.register_autograd(numpy_cat_backward, setup_context=numpy_cat_setup_context)
|
|
|
|
def numpy_cat_vmap(info, in_dims, x, dim):
    """vmap rule for ``numpy_cat``.

    ``x`` is the list of tensors to concatenate, and ``in_dims`` has one entry
    per argument: the per-tensor batch dims for ``x`` and an entry for ``dim``.
    Every tensor is given a leading batch dim (moved there if batched,
    expanded to ``info.batch_size`` if not) and the concatenation runs along
    ``dim + 1``. The output's batch dim is 0.
    """
    xs_bdims, _ = in_dims

    # Wrap a negative ``dim`` against the logical (per-example) rank; a
    # batched tensor carries one extra physical dim.
    first, first_bdim = x[0], xs_bdims[0]
    logical_dim = first.dim() if first_bdim is None else first.dim() - 1
    dim = dim if dim >= 0 else dim + logical_dim

    def expand_bdim(t, bdim):
        # Unbatched tensors are broadcast across the batch.
        if bdim is None:
            return t.expand(info.batch_size, *t.shape)
        return t.movedim(bdim, 0)

    batched = [expand_bdim(t, bdim) for t, bdim in zip(x, xs_bdims)]
    return numpy_cat(batched, dim + 1), 0


numpy_cat.register_vmap(numpy_cat_vmap)
|
|
|
|
def sample_inputs_numpy_cat(opinfo, device, dtype, requires_grad, **kwargs):
    """Yield one sample: three tensors differing only in dim 0, concatenated along dim 0."""
    make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    tensors = [make_arg(leading, 3, 4, low=0.9, high=2) for leading in (2, 4, 5)]
    yield SampleInput(tensors, args=(0,))
|
|
|
|
@torch.library.custom_op('_torch_testing::numpy_split_copy', mutates_args=())
def numpy_split_copy(x: Tensor, splits: Sequence[int], dim: int) -> List[Tensor]:
    """Split ``x`` at indices ``splits`` along ``dim`` with ``np.split``.

    Each piece is returned as a new tensor with ``x``'s device and dtype.
    """
    pieces = np.split(to_numpy(x), splits, axis=dim)
    return [torch.tensor(piece, device=x.device, dtype=x.dtype) for piece in pieces]
|
|
|
|
@numpy_split_copy.register_fake
def _(x, splits, dim):
    """Fake impl for ``numpy_split_copy``: clones of the ``torch.tensor_split`` pieces."""
    pieces = torch.tensor_split(x, splits, dim)
    return [piece.clone() for piece in pieces]
|
|
|
|
def numpy_split_copy_setup_context(ctx, inputs, output):
    """Record the split dimension for the backward pass."""
    _x, _splits, dim = inputs
    ctx.dim = dim
|
|
|
|
def numpy_split_copy_backward(ctx, grad_out):
    """Backward for ``numpy_split_copy``: concatenate the piece gradients along ``ctx.dim``."""
    grad_x = torch.ops._torch_testing.numpy_cat(grad_out, dim=ctx.dim)
    return grad_x, None, None


numpy_split_copy.register_autograd(numpy_split_copy_backward, setup_context=numpy_split_copy_setup_context)
|
|
|
|
def numpy_split_copy_vmap(info, in_dims, x, splits, dim):
    """vmap rule for ``numpy_split_copy``.

    Moves the batch dim of ``x`` to the front and splits along ``dim + 1``;
    every output piece has its batch dim at 0.
    """
    x_bdim, _, _ = in_dims
    x = x.movedim(x_bdim, 0)
    # Wrap a negative ``dim`` against the rank without the batch dim, so that
    # ``dim + 1`` can never land on the batch dim itself.
    dim = dim if dim >= 0 else dim + x.dim() - 1
    result = numpy_split_copy(x, splits, dim + 1)
    return result, 0


numpy_split_copy.register_vmap(numpy_split_copy_vmap)
|
|
|
|
def sample_inputs_numpy_split_copy(opinfo, device, dtype, requires_grad, **kwargs):
    """Yield one ``(2, 9)`` sample split at ``[1, 3, 6]`` along dim 1."""
    source = make_tensor(
        2, 9, device=device, dtype=dtype, requires_grad=requires_grad, low=0.9, high=2
    )
    yield SampleInput(source, args=([1, 3, 6], 1))
|
|
|
|
@torch.library.custom_op('_torch_testing::numpy_split_copy_with_int', mutates_args=())
def numpy_split_copy_with_int(x: Tensor, splits: Sequence[int], dim: int) -> tuple[List[Tensor], int]:
    """Like ``numpy_split_copy``, but also returns ``len(splits)`` as a plain int."""
    pieces = np.split(to_numpy(x), splits, axis=dim)
    tensors = [torch.tensor(piece, device=x.device, dtype=x.dtype) for piece in pieces]
    return tensors, len(splits)
|
|
|
|
@numpy_split_copy_with_int.register_fake
def _(x, splits, dim):
    """Fake impl: clones of the ``torch.tensor_split`` pieces plus ``len(splits)``."""
    pieces = torch.tensor_split(x, splits, dim)
    return [piece.clone() for piece in pieces], len(splits)
|
|
|
|
def numpy_split_copy_with_int_setup_context(ctx, inputs, output):
    """Record the split dimension for the backward pass."""
    _x, _splits, dim = inputs
    ctx.dim = dim
|
|
|
|
def numpy_split_copy_with_int_backward(ctx, grad_out, _):
    """Backward: concatenate the piece gradients along ``ctx.dim``.

    The second incoming gradient (for the int output) is ignored.
    """
    grad_x = torch.ops._torch_testing.numpy_cat(grad_out, dim=ctx.dim)
    return grad_x, None, None


numpy_split_copy_with_int.register_autograd(
    numpy_split_copy_with_int_backward,
    setup_context=numpy_split_copy_with_int_setup_context)
|
|
|
|
def numpy_split_copy_with_int_vmap(info, in_dims, x, splits, dim):
    """vmap rule for ``numpy_split_copy_with_int``.

    Moves the batch dim of ``x`` to the front and splits along ``dim + 1``.
    Every tensor piece has its batch dim at 0; the int output is unbatched.
    """
    x_bdim, _, _ = in_dims
    x = x.movedim(x_bdim, 0)
    # Wrap a negative ``dim`` against the rank without the batch dim, so that
    # ``dim + 1`` can never land on the batch dim itself.
    dim = dim if dim >= 0 else dim + x.dim() - 1
    result, len_split = numpy_split_copy_with_int(x, splits, dim + 1)
    return (result, len_split), ([0] * len(result), None)


numpy_split_copy_with_int.register_vmap(numpy_split_copy_with_int_vmap)
|
|
|
|
@torch.library.custom_op("_torch_testing::numpy_nms", mutates_args=())
def numpy_nms(boxes: Tensor, scores: Tensor, iou_threshold: Number) -> Tensor:
    """Non-maximum suppression over ``(N, 4)`` boxes, implemented with NumPy.

    Visits boxes in descending score order, keeps the current best box and
    drops every remaining box whose overlap ratio with it exceeds
    ``iou_threshold``. Returns the kept indices; asserts at least two remain.
    """
    # Adapted from Ross Girshick's fast-rcnn implementation at
    # https://github.com/rbgirshick/fast-rcnn/blob/master/lib/utils/nms.py
    assert boxes.device == scores.device
    device = boxes.device

    boxes = to_numpy(boxes)
    scores = to_numpy(scores)

    N = boxes.shape[0]
    assert boxes.shape == (N, 4)
    assert scores.shape == (N,)

    x1, y1, x2, y2 = (boxes[:, col] for col in range(4))

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        best, rest = order[0], order[1:]
        keep.append(best)

        # Intersection rectangle between the best box and each remaining box.
        xx1 = np.maximum(x1[best], x1[rest])
        yy1 = np.maximum(y1[best], y1[rest])
        xx2 = np.minimum(x2[best], x2[rest])
        yy2 = np.minimum(y2[best], y2[rest])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[best] + areas[rest] - inter)

        # Only boxes at or below the threshold stay in the running.
        survivors = np.where(ovr <= iou_threshold)[0]
        order = rest[survivors]

    result = torch.tensor(np.stack(keep), device=device)
    # Needed for data-dependent condition :(
    assert result.size(0) >= 2
    return result
|
|
|
|
@numpy_nms.register_fake
def _(boxes, scores, iou_threshold):
    """Fake impl for ``numpy_nms``: validate shapes, return an int64 vector of unbacked length."""
    assert boxes.device == scores.device
    num_boxes = boxes.shape[0]
    assert boxes.shape == (num_boxes, 4)
    assert scores.shape == (num_boxes,)

    num_kept = torch._custom_op.impl.get_ctx().create_unbacked_symint()
    return boxes.new_empty([num_kept], dtype=torch.int64)
|
|
|
|
def numpy_nms_vmap(info, in_dims, boxes, scores, iou_threshold):
    """vmap rule for ``numpy_nms``: always raises ``NotImplementedError``."""
    message = "Operator is data-dependent and cannot be vmapped."
    raise NotImplementedError(message)


numpy_nms.register_vmap(numpy_nms_vmap)
|
|
|
|
def sample_inputs_numpy_nms(opinfo, device, dtype, requires_grad, **kwargs):
    """Yield one sample of 64 random boxes, their scores and a random IoU threshold."""
    make_arg = functools.partial(make_tensor, device=device, dtype=dtype)
    num_boxes = 64
    x_min = make_arg([num_boxes], low=0, high=28)
    x_extent = make_arg([num_boxes], low=0, high=4)
    y_min = make_arg([num_boxes], low=0, high=28)
    y_extent = make_arg([num_boxes], low=0, high=4)

    # Columns are (x1, y1, x2, y2); the far corner is the near corner plus a
    # non-negative extent.
    corners = [x_min, y_min, x_min + x_extent, y_min + y_extent]
    boxes = torch.stack(corners, dim=1).requires_grad_(requires_grad)
    scores = make_arg([num_boxes], low=0, high=1, requires_grad=requires_grad)
    iou_threshold = make_arg([], low=0, high=1).item()

    yield SampleInput(boxes, args=(scores, iou_threshold))
|
|
|
|
# Every entry below uses the same dtypes and has no out= variant, so those
# keyword arguments are bound once here.
_custom_opinfo = functools.partial(
    OpInfo,
    dtypes=all_types_and(torch.bool, torch.half),
    supports_out=False,
)

custom_op_db = [
    _custom_opinfo(
        'NumpyCubeCustomOp',
        op=numpy_cube._opoverload,
        sample_inputs_func=sample_inputs_numpy_cube,
    ),
    _custom_opinfo(
        'NumpyMulCustomOp',
        op=numpy_mul._opoverload,
        sample_inputs_func=sample_inputs_numpy_mul,
    ),
    _custom_opinfo(
        'NumpyMulScalarCustomOp',
        op=numpy_mul_scalar._opoverload,
        sample_inputs_func=sample_inputs_numpy_mul_scalar,
    ),
    _custom_opinfo(
        'NumpySortCustomOp',
        op=numpy_sort._opoverload,
        sample_inputs_func=sample_inputs_numpy_sort,
    ),
    _custom_opinfo(
        'NumpyTakeCustomOp',
        op=numpy_take._opoverload,
        sample_inputs_func=sample_inputs_numpy_take,
    ),
    _custom_opinfo(
        'NumpyNonzeroCustomOp',
        op=numpy_nonzero._opoverload,
        sample_inputs_func=sample_inputs_numpy_nonzero,
        supports_autograd=False,
    ),
    _custom_opinfo(
        'NumpyNMSCustomOp',
        op=torch.ops._torch_testing.numpy_nms,
        sample_inputs_func=sample_inputs_numpy_nms,
        supports_autograd=False,
    ),
    _custom_opinfo(
        'NumpyViewCopyCustomOp',
        op=torch.ops._torch_testing.numpy_view_copy,
        sample_inputs_func=sample_inputs_numpy_view_copy,
        supports_autograd=True,
    ),
    _custom_opinfo(
        'NumpyCatCustomOp',
        op=torch.ops._torch_testing.numpy_cat,
        sample_inputs_func=sample_inputs_numpy_cat,
        supports_autograd=True,
        check_batched_grad=False,
        check_batched_gradgrad=False,
    ),
    _custom_opinfo(
        'NumpySplitCopyCustomOp',
        op=torch.ops._torch_testing.numpy_split_copy,
        sample_inputs_func=sample_inputs_numpy_split_copy,
        supports_autograd=True,
        check_batched_grad=False,
        check_batched_gradgrad=False,
    ),
    _custom_opinfo(
        'NumpySplitCopyWithIntCustomOp',
        op=torch.ops._torch_testing.numpy_split_copy_with_int,
        sample_inputs_func=sample_inputs_numpy_split_copy,
        # Gradcheck only looks at the list-of-tensors output, not the int.
        gradcheck_wrapper=lambda op, *args, **kwargs: op(*args, **kwargs)[0],
        supports_autograd=True,
        check_batched_grad=False,
        check_batched_gradgrad=False,
    ),
]
|
|
|
|
|
|
# ==============================================================
|
|
# some mechanical test cases
|
|
# ==============================================================
|
|
|
|
lib = torch.library.Library("_torch_testing", "FRAGMENT")  # noqa: TOR901

# NOTE(review): source0..source3 differ only in *how* the fake impl is
# registered (decorator vs. direct call). Presumably tests rely on those exact
# forms -- confirm before refactoring them into a loop or helper.

# source0: fake impl registered with the decorator form of register_fake.
lib.define("source0(Tensor x) -> Tensor")


@torch.library.register_fake("_torch_testing::source0", lib=lib)
def _(x):
    return x.clone()


# source1: fake impl registered by passing a named function to register_fake.
lib.define("source1(Tensor x) -> Tensor")


def source1_fake(x):
    return x.clone()


torch.library.register_fake("_torch_testing::source1", source1_fake, lib=lib)

# source2: decorator form again.
lib.define("source2(Tensor x) -> Tensor")


@torch.library.register_fake("_torch_testing::source2", lib=lib)
def _(x):
    return x.clone()


# source3: direct-call form again.
lib.define("source3(Tensor x) -> Tensor")


def source3_fake(x):
    return x.clone()


torch.library.register_fake("_torch_testing::source3", source3_fake, lib=lib)
|
|
|
|
|
|
@torch.library.custom_op("_torch_testing::source4", mutates_args=())
|
|
def source4(x: Tensor) -> Tensor:
|
|
return x.clone()
|
|
|
|
@source4.register_fake
|
|
def _(x):
|
|
return x.clone()
|
|
|
|
@torch.library.custom_op("_torch_testing::source5", mutates_args=())
|
|
def source5(x: Tensor) -> Tensor:
|
|
return x.clone()
|
|
|
|
def source5_fake(x):
|
|
return x.clone()
|
|
|
|
source5.register_fake(source5_fake)
|