mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This takes the strategy described in https://docs.google.com/document/d/1lFRYAJo5nrfxRhwIzGnfi2pbLpU6T4ytSRSuLJ5qebI/edit# It is essentially https://github.com/pytorch/pytorch/pull/95222 but squashed and with changes that are unnecessary given that we assume nonzero returns > 1. What's in the PR: * nonzero now supports meta propagation. When `capture_dynamic_output_shape_ops`, it will return a tensor with an unbacked SymInt representing the size in question. * The unbacked SymInt is UNSOUNDLY assumed to be not equal to 0/1. We will still error if you guard otherwise. * PrimTorch pointwise operators are updated to use empty_permuted, to avoid guarding on unbacked SymInt from empty_strided (tested in `test_dynamic_pointwise_scalar`) * Convolution is updated to skip backend selection if batch is unbacked, to avoid guarding on unbacked SymInt (tested in `test_unbacked_batch_resnet`) * I kept the helper utilities like `definitely_true` for working with possibly unbacked SymInts. They're not used right now but maybe someone will find them useful. * Added `constrain_unify` to let you specify two unbacked SymInts must have the same value Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/95387 Approved by: https://github.com/voznesenskym
2730 lines
84 KiB
Python
2730 lines
84 KiB
Python
import math
|
|
from typing import List, Optional, Union
|
|
|
|
import torch
|
|
import torch._prims_common as utils
|
|
from torch import Tensor
|
|
from torch._decomp import _add_op_to_registry, global_decomposition_table, meta_table
|
|
from torch._ops import OpOverload
|
|
from torch._prims import _elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND
|
|
from torch._prims_common import (
|
|
check,
|
|
corresponding_complex_dtype,
|
|
corresponding_real_dtype,
|
|
elementwise_dtypes,
|
|
ELEMENTWISE_TYPE_PROMOTION_KIND,
|
|
IntLike,
|
|
make_contiguous_strides_for,
|
|
)
|
|
|
|
from torch._prims_common.wrappers import out_wrapper
|
|
from torch._refs import _broadcast_shapes
|
|
|
|
from torch.utils._pytree import tree_map
|
|
|
|
|
|
# Shorthand for the ATen operator namespace; the registrations in this file
# name their overloads through it (e.g. ``aten.mm.default``).
aten = torch.ops.aten

# Library handle for "aten" implementations under the Meta dispatch key.
# NOTE(review): the name suggests new registrations should go through
# ``register_meta`` rather than this handle directly -- confirm.
_meta_lib_dont_use_me_use_register_meta = torch.library.Library("aten", "IMPL", "Meta")
|
|
|
|
|
|
def register_meta(op):
    """Build a decorator that records a function in ``meta_table`` for ``op``.

    ``op`` may be a single operator or any pytree (for example a list) of
    operators; each leaf is passed to ``_add_op_to_registry``.  The decorated
    function itself is returned unchanged.
    """

    def wrapper(fn):
        # tree_map is used purely for its traversal; its result is discarded.
        tree_map(lambda overload: _add_op_to_registry(meta_table, overload, fn), op)
        return fn

    return wrapper
|
|
|
|
|
|
def toRealValueType(dtype):
    """Return the real dtype matching a complex ``dtype``.

    Dtypes that are not one of the three complex types listed here are
    returned unchanged.
    """
    complex_dtypes = (torch.complex32, torch.cfloat, torch.cdouble)
    real_dtypes = (torch.half, torch.float, torch.double)
    real_for = dict(zip(complex_dtypes, real_dtypes))
    try:
        return real_for[dtype]
    except KeyError:
        return dtype
|
|
|
|
|
|
@register_meta([aten._fft_c2c.default, aten._fft_c2c.out])
@out_wrapper()
def meta_fft_c2c(self, dim, normalization, forward):
    """Meta kernel for ``aten._fft_c2c``: the output has the input's size."""
    assert self.dtype.is_complex
    result_size = self.size()
    return self.new_empty(result_size)
|
|
|
|
|
|
@register_meta([aten._fft_r2c.default, aten._fft_r2c.out])
@out_wrapper()
def meta_fft_r2c(self, dim, normalization, onesided):
    """Meta kernel for ``aten._fft_r2c``.

    The output has the complex dtype corresponding to the input dtype.  When
    ``onesided`` is true, the size along ``dim[-1]`` becomes ``size // 2 + 1``.
    """
    assert self.dtype.is_floating_point
    shape = list(self.size())

    if onesided:
        axis = dim[-1]
        shape[axis] = (shape[axis] // 2) + 1

    complex_dtype = utils.corresponding_complex_dtype(self.dtype)
    return self.new_empty(shape, dtype=complex_dtype)
|
|
|
|
|
|
@register_meta(aten.randperm.generator_out)
def meta_randperm(n, *, generator=None, out):
    """Meta kernel for ``aten.randperm.generator_out``: validate and return ``out``."""
    # ``out`` must already be a 1-D tensor holding exactly ``n`` elements.
    assert out.ndim == 1
    assert out.size(0) == n
    return out
|
|
|
|
|
|
@register_meta(aten.randint.default)
def meta_randint(
    high, size, *, dtype=torch.long, layout=None, device=None, pin_memory=None
):
    """Meta kernel for ``aten.randint.default``: an empty tensor of ``size``."""
    factory_kwargs = dict(
        dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )
    return torch.empty(size, **factory_kwargs)
|
|
|
|
|
|
@register_meta(aten.randint.low)
def meta_randint_low(
    low, high, size, *, dtype=torch.long, layout=None, device=None, pin_memory=None
):
    """Meta kernel for ``aten.randint.low``: an empty tensor of ``size``."""
    factory_kwargs = dict(
        dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )
    return torch.empty(size, **factory_kwargs)
|
|
|
|
|
|
@register_meta(aten.rand.default)
def meta_rand_default(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Meta kernel for ``aten.rand.default``: an empty tensor of ``size``."""
    factory_kwargs = dict(
        dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )
    return torch.empty(size, **factory_kwargs)
|
|
|
|
|
|
@register_meta([aten._fft_c2r.default, aten._fft_c2r.out])
@out_wrapper()
def meta_fft_c2r(self, dim, normalization, lastdim):
    """Meta kernel for ``aten._fft_c2r``.

    The output has the real dtype matching the complex input, and its size
    along ``dim[-1]`` is ``lastdim``.
    """
    assert self.dtype.is_complex
    shape = list(self.size())
    last_axis = dim[-1]
    shape[last_axis] = lastdim
    real_dtype = toRealValueType(self.dtype)
    return self.new_empty(shape, dtype=real_dtype)
|
|
|
|
|
|
@register_meta(aten.copy_.default)
def meta_copy_(self, src, non_blocking=False):
    """Meta kernel for ``aten.copy_``: returns ``self`` without inspecting ``src``."""
    return self
|
|
|
|
|
|
def inferUnsqueezeGeometry(tensor, dim):
    """Compute sizes and strides of ``tensor`` with a size-1 dim inserted at ``dim``.

    Returns a ``(sizes, strides)`` pair of lists.  The inserted stride is 1
    when ``dim`` is past the last existing dimension, otherwise
    ``size[dim] * stride[dim]`` of the dimension currently at that position.
    """
    sizes, strides = list(tensor.size()), list(tensor.stride())
    if dim >= tensor.dim():
        inserted_stride = 1
    else:
        inserted_stride = sizes[dim] * strides[dim]
    sizes.insert(dim, 1)
    strides.insert(dim, inserted_stride)
    return sizes, strides
|
|
|
|
|
|
@register_meta(aten.unsqueeze_.default)
def meta_unsqueeze_(self, dim):
    """Meta kernel for in-place ``unsqueeze_``: restride ``self`` and return it."""
    # The wrapped dim ranges over self.dim() + 1 positions because the new
    # dimension may also be appended after the last existing one.
    wrapped_dim = maybe_wrap_dim(dim, self.dim() + 1)
    new_sizes, new_strides = inferUnsqueezeGeometry(self, wrapped_dim)
    self.as_strided_(new_sizes, new_strides)
    return self
|
|
|
|
|
|
# Implementations below are taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py
|
|
@register_meta(aten.index_select.default)
def meta_index_select(self, dim, index):
    """Meta kernel for ``aten.index_select``.

    The output has ``self``'s size, except that dimension ``dim`` holds
    ``index.numel()`` entries.  A 0-dim ``self`` keeps its (empty) size.
    """
    out_size = list(self.size())
    if self.dim() > 0:
        out_size[dim] = index.numel()
    return self.new_empty(out_size)
|
|
|
|
|
|
@register_meta(aten.index_select.out)
def meta_index_select_out(self, dim, index, out):
    """Meta kernel for ``aten.index_select.out``.

    Computes the functional ``index_select`` result, resizes ``out`` to that
    result's size and copies the result into it.  Returns ``out``.
    """
    result = torch.index_select(self, dim, index)
    # Resize to the *result* size: along ``dim`` the output holds
    # index.numel() entries, which in general differs from self.size(dim).
    torch._resize_output_(out, result.size(), self.device)
    return out.copy_(result)
|
|
|
|
|
|
@register_meta([aten.max.default, aten.max.unary_out])
@out_wrapper()
def meta_max(self):
    """Meta kernel for the full reduction ``aten.max``: a 0-dim tensor."""
    scalar_shape = ()
    return self.new_empty(scalar_shape)
|
|
|
|
|
|
@register_meta(aten.max.dim)
def meta_max_dim(self, dim, keepdim=False):
    """Meta kernel for ``aten.max.dim``.

    Returns a ``(values, indices)`` pair of empty tensors with the reduced
    shape; ``indices`` has dtype ``torch.long``.
    """
    reduced_dims = utils.reduction_dims(self.shape, (dim,))
    reduced_shape = _compute_reduction_shape(self, reduced_dims, keepdim)
    values = self.new_empty(reduced_shape)
    indices = self.new_empty(reduced_shape, dtype=torch.long)
    return (values, indices)
|
|
|
|
|
|
@register_meta([aten.min.default])
def meta_min(self):
    """Meta kernel for the full reduction ``aten.min``: a 0-dim tensor."""
    scalar_shape = ()
    return self.new_empty(scalar_shape)
|
|
|
|
|
|
@register_meta(aten.angle.default)
def meta_angle(self):
    """Meta kernel for ``aten.angle``.

    Complex inputs produce the corresponding real dtype; other inputs use the
    INT_TO_FLOAT elementwise promotion result dtype.
    """
    if not self.is_complex():
        promoted = elementwise_dtypes(
            self, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
        )
        _, out_dtype = promoted
    else:
        out_dtype = corresponding_real_dtype(self.dtype)
    return torch.empty_like(self, dtype=out_dtype)
|
|
|
|
|
|
@register_meta(aten.angle.out)
def meta_angle_out(self, out):
    """Meta kernel for ``aten.angle.out``: resize ``out`` to ``self``'s size and fill it."""
    torch._resize_output_(out, self.size(), self.device)
    result = torch.angle(self)
    return out.copy_(result)
|
|
|
|
|
|
# From aten/src/ATen/native/LinearAlgebraUtils.h
|
|
def squareCheckInputs(self: Tensor, f_name: str):
    """Assert that ``self`` is a batch of square matrices.

    Raises ``AssertionError`` (prefixed with ``f_name``) when ``self`` has
    fewer than two dimensions or its last two sizes differ.
    """
    ndim = self.dim()
    assert ndim >= 2, f"{f_name}: The input tensor must have at least 2 dimensions."
    rows, cols = self.size(-2), self.size(-1)
    assert (
        cols == rows
    ), f"{f_name}: A must be batches of square matrices, but they are {rows} by {cols} matrices"
|
|
|
|
|
|
# From aten/src/ATen/native/LinearAlgebraUtils.h
|
|
def checkFloatingOrComplex(
    t: Tensor, f_name: str, allow_low_precision_dtypes: bool = True
):
    """Check that ``t`` has a floating point or complex dtype.

    Args:
        t: tensor whose dtype is validated.
        f_name: operator name used as the error-message prefix.
        allow_low_precision_dtypes: when False, additionally require the
            dtype to be one of float, double, cfloat or cdouble.

    Raises:
        Whatever ``check`` raises on a failed condition.
    """
    dtype = t.dtype
    check(
        t.is_floating_point() or t.is_complex(),
        lambda: f"{f_name}: Expected a floating point or complex tensor as input. Got {dtype}",
    )
    # The restriction applies only when low precision is NOT allowed; the
    # previous condition was inverted and rejected e.g. half by default.
    if not allow_low_precision_dtypes:
        check(
            dtype in (torch.float, torch.double, torch.cfloat, torch.cdouble),
            lambda: f"{f_name}: Low precision dtypes not supported. Got {dtype}",
        )
|
|
|
|
|
|
# From aten/src/ATen/native/LinearAlgebraUtils.h
|
|
def checkIsMatrix(A: Tensor, f_name: str, arg_name: str = "A"):
    """Check that ``A`` has at least two dimensions, reporting it as ``arg_name``."""

    def _message():
        return f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions."

    has_matrix_dims = A.dim() >= 2
    check(has_matrix_dims, _message)
|
|
|
|
|
|
def checkUplo(uplo: str):
    """Assert that ``uplo`` names a triangle: 'L' or 'U', case-insensitively."""
    normalized = uplo.upper()
    # Grouping made explicit; it matches Python's and-before-or precedence.
    is_lower = normalized == "L"
    is_upper = normalized == "U" and len(uplo) == 1
    assert (
        is_upper or is_lower
    ), f"Expected UPLO argument to be 'L' or 'U', but got {uplo}"
|
|
|
|
|
|
# @register_meta(aten.linalg_eigh.default)
|
|
def meta_linalg_eigh(self, uplo="L"):
    """Shape inference for ``linalg_eigh`` (not registered; see the comment above).

    Args:
        self: batch of square matrices, validated by ``squareCheckInputs``.
        uplo: 'L' or 'U', validated by ``checkUplo``.

    Returns:
        ``(values, vectors)``: ``values`` has the batch shape plus one
        trailing dimension (``self.shape[:-1]``) in the real dtype matching
        ``self``; ``vectors`` has ``self``'s shape and dtype.
    """
    squareCheckInputs(self, "linalg_eigh")
    checkUplo(uplo)
    real_dtype = toRealValueType(self.dtype)
    # One real eigenvalue per matrix row. The original allocated the full
    # matrix shape here and the reduced shape for the vectors, i.e. swapped.
    values = self.new_empty(self.shape[:-1], dtype=real_dtype)
    vectors = self.new_empty(self.shape)
    # The original swapped the last two strides of its full-size output;
    # that is kept here on the matrix-shaped result.
    # NOTE(review): presumably mimics a column-major layout -- confirm.
    vectors.transpose_(-2, -1)
    return (values, vectors)
|
|
|
|
|
|
# From aten/src/ATen/native/BatchLinearAlgebra.cpp
|
|
@register_meta(aten.linalg_cholesky_ex.default)
def linalg_cholesky_ex(A: Tensor, upper: bool = False, check_errors: bool = False):
    """Meta kernel for ``aten.linalg_cholesky_ex``.

    Returns ``(L, infos)``: ``L`` has ``A``'s shape with strides from
    ``make_contiguous_strides_for(shape, False)``; ``infos`` is an int32
    tensor over the batch dimensions.  ``upper`` and ``check_errors`` do not
    affect the result.
    """
    squareCheckInputs(A, "linalg.cholesky")
    checkFloatingOrComplex(A, "linalg.cholesky")

    shape = A.shape

    # L
    L = A.new_empty(shape)
    L.as_strided_(shape, make_contiguous_strides_for(shape, False))

    # infos: one entry per matrix in the batch
    batch_shape = shape[: len(shape) - 2]
    infos = A.new_empty(batch_shape, dtype=torch.int32)
    return L, infos
|
|
|
|
|
|
# From aten/src/ATen/native/BatchLinearAlgebra.cpp
|
|
@register_meta(aten.linalg_inv_ex.default)
def linalg_inv_ex_meta(A: Tensor, check_errors: bool = False):
    """Meta kernel for ``aten.linalg_inv_ex``.

    Returns ``(L, infos)``: ``L`` has ``A``'s shape with non-row-major
    strides; ``infos`` is an int32 tensor over the batch dimensions.
    """
    squareCheckInputs(A, "linalg.inv_ex")
    checkFloatingOrComplex(A, "linalg.inv_ex", allow_low_precision_dtypes=False)

    shape = A.shape
    L = A.new_empty(shape)
    column_major = make_contiguous_strides_for(shape, row_major=False)
    L.as_strided_(shape, column_major)

    batch_shape = shape[:-2]
    infos = A.new_empty(batch_shape, dtype=torch.int32)
    return L, infos
|
|
|
|
|
|
# From aten/src/ATen/native/BatchLinearAlgebra.cpp
|
|
# NOTE: matching defaults in aten/src/ATen/native/native_functions.yaml
|
|
@register_meta(aten._linalg_svd.default)
def _linalg_svd_meta(
    A: Tensor, full_matrices: bool = False, compute_uv: bool = True, driver: str = None
):
    """Meta kernel for ``aten._linalg_svd``.

    Returns ``(U, S, V)``.  With ``compute_uv`` the factors are
    ``(..., m, m or k)`` and ``(..., n or k, n)`` with non-row-major strides,
    where ``k = min(m, n)`` and the larger size is used when
    ``full_matrices`` is set; otherwise both are size-``[0]`` placeholders.
    ``S`` is ``(..., k)`` in the real dtype matching ``A``.
    """
    checkIsMatrix(A, "linalg.svd")
    checkFloatingOrComplex(A, "linalg.svd")

    batch_dims = list(A.shape[:-2])
    m, n = A.shape[-2], A.shape[-1]
    k = min(m, n)

    def _empty_column_major(shape):
        # Allocate, then restride in place to the non-row-major layout.
        tensor = A.new_empty(shape)
        tensor.as_strided_(shape, make_contiguous_strides_for(shape, row_major=False))
        return tensor

    if not compute_uv:
        # doesn't matter
        U = A.new_empty([0])
        V = A.new_empty([0])
    else:
        U = _empty_column_major(batch_dims + [m, m if full_matrices else k])
        # TODO: need to distinguish cuSOLVER case? (see original code)
        V = _empty_column_major(batch_dims + [n if full_matrices else k, n])

    # S is always real, even when A is complex.
    S = A.new_empty(batch_dims + [k], dtype=toRealValueType(A.dtype))
    return U, S, V
|
|
|
|
|
|
# From aten/src/ATen/native/LinearAlgebra.cpp
|
|
@register_meta(aten._linalg_det.default)
def _linalg_det_meta(A):
    """Meta kernel for ``aten._linalg_det``.

    Returns ``(det, LU, pivots)``: ``det`` over the batch dimensions, ``LU``
    with ``A``'s shape and non-row-major strides, and int32 ``pivots`` of
    shape ``A.shape[:-1]``.
    """
    squareCheckInputs(A, "linalg.det")
    checkFloatingOrComplex(A, "linalg.det")

    shape = A.shape
    det = A.new_empty(shape[:-2])

    LU = A.new_empty(shape)
    column_major = make_contiguous_strides_for(shape, row_major=False)
    LU.as_strided_(shape, column_major)

    pivots = A.new_empty(shape[:-1], dtype=torch.int32)
    return det, LU, pivots
|
|
|
|
|
|
# From aten/src/ATen/native/ReflectionPad.cpp
|
|
@register_meta(
    [aten.reflection_pad2d_backward.default, aten.replication_pad2d_backward.default]
)
def meta_pad2d_backward(grad_output, self, padding):
    """Meta kernel shared by the reflection and replication pad2d backward ops.

    Checks that ``grad_output``'s height and width equal ``self``'s plus the
    padding, then returns an empty tensor shaped like ``self``.
    ``padding`` is read as ``(left, right, top, bottom)``.
    """
    # A 4-D input has a leading batch dimension, shifting H and W by one.
    batch_offset = 1 if self.dim() == 4 else 0
    dim_h = 1 + batch_offset
    dim_w = 2 + batch_offset

    pad_l, pad_r, pad_t, pad_b = padding[0], padding[1], padding[2], padding[3]

    in_shape = self.shape
    output_h = in_shape[dim_h] + pad_t + pad_b
    output_w = in_shape[dim_w] + pad_l + pad_r

    check(
        output_w == grad_output.shape[dim_w],
        lambda: f"gradOutput width unexpected. Expected: {output_w}, Got: {grad_output.shape[dim_w]}",
    )
    check(
        output_h == grad_output.shape[dim_h],
        lambda: f"gradOutput height unexpected. Expected: {output_h}, Got: {grad_output.shape[dim_h]}",
    )
    return self.new_empty(in_shape)
|
|
|
|
|
|
@register_meta(aten.reflection_pad2d.default)
def meta_pad2d(self, padding):
    """Meta kernel for ``aten.reflection_pad2d``.

    Args:
        self: 3-D ``(plane, H, W)`` or 4-D ``(batch, plane, H, W)`` input
            whose non-leading dimensions are all non-empty.
        padding: ``(left, right, top, bottom)`` padding amounts.

    Returns:
        An empty tensor with the input's leading dimensions and padded
        height and width.

    Raises:
        Whatever ``check`` raises when the input is not a valid 3-D/4-D tensor.
    """
    ndim = self.ndim
    # Only index dimensions that exist. Reading size(1)/size(2) before the
    # rank test made inputs with fewer than 3 dimensions fail with an
    # IndexError instead of the error message below.
    valid_input = ndim in (3, 4) and all(self.size(d) != 0 for d in range(1, ndim))
    check(
        valid_input,
        lambda: f"3D or 4D (batch mode) tensor expected for input, but got: {self}",
    )

    pad_l, pad_r, pad_t, pad_b = padding

    *leading, input_h, input_w = self.shape
    output_h = input_h + pad_t + pad_b
    output_w = input_w + pad_l + pad_r

    # ``leading`` is (plane,) for 3-D input and (batch, plane) for 4-D input.
    return self.new_empty((*leading, output_h, output_w))
|
|
|
|
|
|
@register_meta([aten.bernoulli.default, aten.bernoulli.out])
@out_wrapper()
def meta_bernoulli(self, *, generator=None):
    """Meta kernel for ``aten.bernoulli``: a contiguous empty tensor like ``self``."""
    # https://github.com/pytorch/pytorch/issues/88612
    like_self = torch.empty_like(self)
    return like_self.contiguous()
|
|
|
|
|
|
@register_meta(aten.bernoulli_.float)
def meta_bernoulli_(self, p=0.5, generator=None):
    """Meta kernel for in-place ``bernoulli_``: ``self`` is returned as is."""
    return self
|
|
|
|
|
|
@register_meta(aten.bernoulli.p)
def meta_bernoulli_p(self, p=0.5, generator=None):
    """Meta kernel for ``aten.bernoulli.p``: a contiguous empty tensor like ``self``."""
    # https://github.com/pytorch/pytorch/issues/88612
    like_self = torch.empty_like(self)
    return like_self.contiguous()
|
|
|
|
|
|
@register_meta(aten._fused_moving_avg_obs_fq_helper.default)
def meta__fused_moving_avg_obs_fq_helper(
    self,
    observer_on,
    fake_quant_on,
    running_min,
    running_max,
    scale,
    zero_point,
    averaging_const,
    quant_min,
    quant_max,
    ch_axis,
    per_row_fake_quant=False,
    symmetric_quant=False,
):
    """Meta kernel for ``aten._fused_moving_avg_obs_fq_helper``.

    Only ``self`` and ``ch_axis`` are consulted.  Returns ``(output, mask)``,
    both shaped like ``self``; ``mask`` has dtype ``torch.bool``.
    """
    axis_in_range = ch_axis < self.dim()
    check(
        axis_in_range,
        lambda: "Error in fused_moving_avg_obs_fake_quant_cpu: ch_axis must be < self.dim()",
    )
    mask = torch.empty_like(self, dtype=torch.bool)
    output = torch.empty_like(self)
    return (output, mask)
|
|
|
|
|
|
def dot_check(self, other):
    """Check that both operands of a dot product are 1-D tensors."""

    def _message():
        return f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors"

    both_vectors = self.dim() == 1 and other.dim() == 1
    check(both_vectors, _message)
|
|
|
|
|
|
@register_meta(aten.dot.default)
def meta_dot(self, tensor):
    """Meta kernel for ``aten.dot``: validate both operands, return a 0-dim tensor."""
    dot_check(self, tensor)
    scalar_shape = ()
    return self.new_empty(scalar_shape)
|
|
|
|
|
|
@register_meta([aten.mm.default])
def meta_mm(a, b):
    """Meta kernel for ``aten.mm``: ``(N, M) @ (M, P)`` gives an empty ``(N, P)``."""
    check(a.dim() == 2, lambda: "a must be 2D")
    check(b.dim() == 2, lambda: "b must be 2D")
    rows, inner_a = a.shape
    inner_b, cols = b.shape
    check(inner_a == inner_b, lambda: "a and b must have same reduction dim")
    return a.new_empty(rows, cols)
|
|
|
|
|
|
def _compute_reduction_shape(self, dims, keepdim):
|
|
if keepdim:
|
|
return tuple(self.shape[i] if i not in dims else 1 for i in range(self.ndim))
|
|
|
|
return utils.compute_reduction_output_shape(self.shape, dims)
|
|
|
|
|
|
# FakeTensors (meta tensors with a device) will report device as meta
# when running meta kernels. Here, access the "fake device" of FakeTensor if it
# exists so meta kernels whose behavior diverges per device will be more
# accurate when run with FakeTensors
|
|
def device_hint(tensor) -> "str":
    """Return the device type to assume for ``tensor`` inside a meta kernel.

    FakeTensors report the type of their ``fake_device``; anything else
    falls back to ``"cuda"``.
    """
    if not isinstance(tensor, torch._subclasses.FakeTensor):
        return "cuda"  # default to cuda
    return tensor.fake_device.type
|
|
|
|
|
|
def calc_conv_nd_return_shape(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    stride: Union[List[int], int],
    padding: Union[List[int], int],
    dilation: Union[List[int], int],
    is_transposed: bool,
    groups: int,
    output_padding: Optional[Union[List[int], int]] = None,
):
    """Compute the output shape of an N-d (optionally transposed) convolution.

    ``stride``, ``padding``, ``dilation`` and ``output_padding`` may each be
    an int-like scalar, a one-element sequence (both are repeated for every
    spatial dimension) or a full per-dimension sequence.

    The transposed-convolution length formula is used only when a truthy
    ``output_padding`` is supplied; otherwise the regular formula is used.

    Returns:
        A list ``[batch, out_channels, *spatial]``.

    Raises:
        RuntimeError: for a non-transposed convolution whose weight channel
            count times ``groups`` differs from the input channel count.
    """

    def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
        """
        Output length of one dimension for a regular convolution.

        See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html

        Args:
            ln: length of the dimension
            p: padding in that dim
            d: dilation in that dim
            k: kernel size in that dim
            s: stride in that dim
        Returns:
            The output length
        """
        return (ln + 2 * p - d * (k - 1) - 1) // s + 1

    def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
        """
        Output length of one dimension for a transposed convolution.

        See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html

        Args:
            ln: length of the dimension
            p: padding in that dim
            d: dilation in that dim
            k: kernel size in that dim
            s: stride in that dim
            op: output padding in that dim

        Returns:
            The output length
        """
        return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1

    kernel_size = weight.shape[2:]
    dims = input_tensor.shape[2:]
    num_spatial = len(dims)

    def _per_dim(value):
        # Repeat a scalar or single-element sequence once per spatial
        # dimension; longer sequences are passed through untouched.
        if isinstance(value, IntLike):
            return [value] * num_spatial
        if len(value) == 1:
            return [value[0]] * num_spatial
        return value

    if is_transposed:
        out_channels = groups * weight.shape[1]
    else:
        out_channels = weight.shape[0]
        if weight.shape[1] * groups != input_tensor.shape[1]:
            raise RuntimeError("Invalid channel dimensions")

    ret_shape = [input_tensor.shape[0], out_channels]

    stride = _per_dim(stride)
    padding = _per_dim(padding)
    dilation = _per_dim(dilation)

    output_padding_list: Optional[List[int]] = None
    if output_padding:
        output_padding_list = _per_dim(output_padding)

    for i in range(num_spatial):
        # If output_padding is present, we are dealing with a transposed convolution
        if output_padding_list:
            length = _formula_transposed(
                dims[i],
                padding[i],
                dilation[i],
                kernel_size[i],
                stride[i],
                output_padding_list[i],
            )
        else:
            length = _formula(
                dims[i], padding[i], dilation[i], kernel_size[i], stride[i]
            )
        ret_shape.append(length)

    return ret_shape
|
|
|
|
|
|
def is_channels_last(ten):
    """Return whether the suggested memory format for ``ten`` is channels-last."""
    suggested = torch._prims_common.suggest_memory_format(ten)
    return suggested == torch.channels_last
|
|
|
|
|
|
@register_meta(aten.convolution.default)
def meta_conv(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    is_transposed: bool,
    output_padding: List[int],
    groups: int,
):
    """Meta kernel for ``aten.convolution``.

    The output shape comes from ``calc_conv_nd_return_shape``
    (``output_padding`` is forwarded only for transposed convolutions) and
    the result is converted to the memory format chosen by
    ``pick_memory_format``.  ``bias`` is not consulted.
    """

    def pick_memory_format():
        # On "cuda" a channels-last weight is enough to select a
        # channels-last output; elsewhere only the input's format counts.
        on_cuda = device_hint(input_tensor) == "cuda"
        if is_channels_last(input_tensor) or (on_cuda and is_channels_last(weight)):
            return torch.channels_last
        if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
            return torch.contiguous_format
        if input_tensor.is_contiguous(memory_format=torch.preserve_format):
            return torch.preserve_format
        # Falls through and returns None when no format matched.

    transposed_output_padding = output_padding if is_transposed else None
    shape_out = calc_conv_nd_return_shape(
        input_tensor,
        weight,
        stride,
        padding,
        dilation,
        is_transposed,
        groups,
        transposed_output_padding,
    )

    out = input_tensor.new_empty(shape_out)
    return out.to(memory_format=pick_memory_format())  # type: ignore[call-overload]
|
|
|
|
|
|
# Meta kernels for the mkldnn (and, nested inside, mkl) operator namespaces;
# they are only defined when the corresponding backend is compiled in.
if torch._C.has_mkldnn:
    _meta_lib_dont_use_me_use_register_meta_for_mkldnn = torch.library.Library(
        "mkldnn", "IMPL", "Meta"
    )

    def pick_mkldnn_conv_memory_format(input_tensor, weight):
        """Choose the output memory format for an mkldnn convolution.

        Channels-last wins if the weight is an mkldnn tensor or if either
        operand is channels-last.  Returns None implicitly when nothing matches.
        """
        prefers_channels_last = (
            weight.is_mkldnn
            or is_channels_last(input_tensor)
            or is_channels_last(weight)
        )
        if prefers_channels_last:
            return torch.channels_last
        if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
            return torch.contiguous_format
        if input_tensor.is_contiguous(memory_format=torch.preserve_format):
            return torch.preserve_format

    @register_meta(torch.ops.mkldnn._convolution_pointwise.default)
    def meta_mkldnn_convolution_default(
        input_tensor,
        weight,
        bias,
        padding,
        stride,
        dilation,
        groups,
        attr,
        scalars,
        algorithm,
    ):
        """Meta kernel for the fused mkldnn convolution: a channels-last output."""
        shape_out = calc_conv_nd_return_shape(
            input_tensor, weight, stride, padding, dilation, False, groups, []
        )
        result = input_tensor.new_empty(shape_out)
        return result.to(memory_format=torch.channels_last)  # type: ignore[call-overload]

    @register_meta(torch.ops.mkldnn._convolution_pointwise.binary)
    def meta_mkldnn_convolution_binary(
        input_tensor,
        other,
        weight,
        bias,
        padding,
        stride,
        dilation,
        groups,
        binary_attr,
        alpha,
        unary_attr,
        unary_scalars,
        unary_algorithm,
    ):
        """Meta kernel for the binary-fused convolution: channels-last, sized like ``other``."""
        result = input_tensor.new_empty(other.size())
        return result.to(memory_format=torch.channels_last)  # type: ignore[call-overload]

    @register_meta(torch.ops.mkldnn._convolution_pointwise_.binary)
    def meta_mkldnn_convolution_binary_inplace(
        input_tensor,
        other,
        weight,
        bias,
        padding,
        stride,
        dilation,
        groups,
        binary_attr,
        alpha,
        unary_attr,
        unary_scalars,
        unary_algorithm,
    ):
        """Meta kernel for the in-place binary-fused convolution: returns ``other``."""
        return other

    @register_meta(torch.ops.mkldnn._linear_pointwise.default)
    def meta_linear_pointwise_default(
        input_tensor, weight, bias, attr, scalars, algorithm
    ):
        """Meta kernel for the fused mkldnn linear: last dim becomes ``weight.shape[0]``."""
        leading_dims = input_tensor.shape[:-1]
        out_features = weight.shape[0]
        return input_tensor.new_empty((*leading_dims, out_features))

    @register_meta(torch.ops.mkldnn._linear_pointwise.binary)
    def meta_linear_pointwise_binary(input_tensor, other, weight, bias, attr):
        """Meta kernel for the binary-fused mkldnn linear: output sized like ``other``."""
        return input_tensor.new_empty(other.size())

    if torch._C.has_mkl:
        _meta_lib_dont_use_me_use_register_meta_for_mkl = torch.library.Library(
            "mkl", "IMPL", "Meta"
        )

        @register_meta(torch.ops.mkl._mkl_linear)
        def meta_mkl_linear(
            input_tensor,
            packed_weight,
            orig_weight,
            bias,
            batch_size,
        ):
            """Meta kernel for ``mkl._mkl_linear``: last dim becomes ``orig_weight.shape[0]``."""
            leading_dims = input_tensor.shape[:-1]
            out_features = orig_weight.shape[0]
            return input_tensor.new_empty((*leading_dims, out_features))
|
|
|
|
|
|
# from check_dim_size() in aten/src/ATen/TensorUtils.cpp.
|
|
def check_dim_size(tensor, dim, dim_size, size):
    """Check that ``tensor`` has ``dim`` dimensions and ``size`` entries along ``dim_size``.

    Args:
        tensor: tensor to validate.
        dim: expected number of dimensions.
        dim_size: index of the dimension whose size is checked.
        size: expected size of that dimension.

    Raises:
        Whatever ``check`` raises when either expectation is not met.
    """
    ndim = tensor.dim()

    def _message():
        # When the rank is wrong, ``dim_size`` may not be a valid index.
        # Indexing unconditionally here raised IndexError while building the
        # message and hid the real error.
        if -ndim <= dim_size < ndim:
            actual = tensor.shape[dim_size]
        else:
            actual = "<out of range>"
        return (
            f"Expected a tensor of dimension {dim} and tensor.size[{dim_size}] == {size}, "
            f"but got : dimension {ndim} and tensor.size[{dim_size}] = {actual}"
        )

    check(ndim == dim and tensor.shape[dim_size] == size, _message)
|
|
|
|
|
|
@register_meta(aten.avg_pool2d.default)
def meta_avg_pool2d(
    input,
    kernel_size,
    stride=(),
    padding=(0,),
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    """Meta kernel for ``aten.avg_pool2d``.

    ``kernel_size`` and ``padding`` hold one or two ints and ``stride`` zero,
    one or two; an empty ``stride`` defaults to the kernel size.  Output
    height and width come from ``pooling_output_shape``, and the result uses
    the memory format suggested for ``input``.
    """

    def unpack(name, val):
        # Expand a one- or two-element sequence into an (H, W) pair.
        check(
            len(val) in [1, 2],
            lambda: f"avg_pool2d: {name} must either be a single int, or a tuple of two ints",
        )
        first = val[0]
        if len(val) == 1:
            return first, first
        return first, val[1]

    kH, kW = unpack("kernel_size", kernel_size)

    check(
        len(stride) in [0, 1, 2],
        lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
    )
    # An omitted stride falls back to the kernel size.
    dH, dW = unpack("stride", stride) if len(stride) > 0 else (kH, kW)

    padH, padW = unpack("padding", padding)

    check(
        divisor_override is None or divisor_override != 0,
        lambda: "divisor must be not zero",
    )

    is_batched = input.dim() == 4
    nbatch = input.size(-4) if is_batched else 1
    nInputPlane, inputHeight, inputWidth = input.size(-3), input.size(-2), input.size(-1)

    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)

    memory_format = utils.suggest_memory_format(input)
    pool2d_shape_check(
        input,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        1,
        1,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        memory_format,
    )

    size = [nInputPlane, outputHeight, outputWidth]
    if input.dim() != 3:
        size = [nbatch] + size
    return torch.empty(
        size, dtype=input.dtype, device=input.device, memory_format=memory_format
    )
|
|
|
|
|
|
# from avg_pool2d_backward_shape_check() in aten/src/ATen/native/Pool.h.
|
|
def avg_pool2d_backward_shape_check(
    input,
    gradOutput,
    nbatch,
    kH,
    kW,
    dH,
    dW,
    padH,
    padW,
    nInputPlane,
    inputHeight,
    inputWidth,
    outputHeight,
    outputWidth,
    mem_format,
):
    """Validate the arguments of ``avg_pool2d_backward``.

    Runs the forward ``pool2d_shape_check`` with the two constant arguments
    set to 1, then checks that ``gradOutput`` has ``input``'s rank and that
    its last three sizes are ``(nInputPlane, outputHeight, outputWidth)``.
    ``nbatch`` is accepted but not consulted.
    """
    pool2d_shape_check(
        input,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        1,
        1,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        mem_format,
    )

    ndim = input.dim()
    nOutputPlane = nInputPlane

    # (offset from the end, expected size) for plane, height and width.
    trailing_expectations = (
        (3, nOutputPlane),
        (2, outputHeight),
        (1, outputWidth),
    )
    for offset, expected in trailing_expectations:
        check_dim_size(gradOutput, ndim, ndim - offset, expected)
|
|
|
|
|
|
# Don't override the C++ registration.
|
|
@register_meta(aten.avg_pool2d_backward.default)
def meta_avg_pool2d_backward(
    gradOutput_,
    input,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override,
):
    """Meta kernel for ``aten.avg_pool2d_backward``.

    Validates the pooling arguments and the shape of ``gradOutput_``, then
    returns an empty tensor with ``input``'s shape, dtype and device in the
    memory format suggested for ``input``.
    """
    # From aten/src/ATen/native/AveragePool2d.cpp structured kernel meta func.
    num_kernel = len(kernel_size)
    check(
        num_kernel == 1 or num_kernel == 2,
        lambda: "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints",
    )
    kH = kernel_size[0]
    kW = kernel_size[1] if num_kernel != 1 else kH

    num_stride = len(stride)
    check(
        num_stride == 0 or num_stride == 1 or num_stride == 2,
        lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
    )
    if num_stride == 0:
        # An omitted stride defaults to the kernel size.
        dH, dW = kH, kW
    elif num_stride == 1:
        dH = dW = stride[0]
    else:
        dH, dW = stride[0], stride[1]

    num_padding = len(padding)
    check(
        num_padding == 1 or num_padding == 2,
        lambda: "avg_pool2d: padding must either be a single int, or a tuple of two ints",
    )
    padH = padding[0]
    padW = padding[1] if num_padding != 1 else padH

    check(
        divisor_override is None or divisor_override != 0,
        lambda: "divisor must be not zero",
    )

    input_size = input.shape
    nbatch = input_size[-4] if input.dim() == 4 else 1
    nInputPlane, inputHeight, inputWidth = input_size[-3], input_size[-2], input_size[-1]

    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)

    mem_format = utils.suggest_memory_format(input)

    avg_pool2d_backward_shape_check(
        input,
        gradOutput_,
        nbatch,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        mem_format,
    )

    return torch.empty(
        input_size, dtype=input.dtype, device=input.device, memory_format=mem_format
    )
|
|
|
|
|
|
@register_meta(aten._adaptive_avg_pool2d.default)
def meta_adaptive_avg_pool2d(self, output_size):
    """Meta kernel for ``aten._adaptive_avg_pool2d``.

    Replaces the last two dimensions of a 3-D or 4-D input with
    ``output_size``.
    """
    check(
        self.ndim == 3 or self.ndim == 4,
        lambda: f"Expected 3D or 4D tensor, but got {self.shape}",
    )
    leading_dims = self.shape[:-2]
    output_shape = leading_dims + tuple(output_size)
    # need to set memory_format to preserve the memory format of the input
    # channel last input should have channel last output
    memory_format = utils.suggest_memory_format(self)
    return torch.empty(
        output_shape, dtype=self.dtype, device=self.device, memory_format=memory_format
    )
|
|
|
|
|
|
@register_meta(aten._adaptive_avg_pool3d.default)
def meta_adaptive_avg_pool3d(self, output_size):
    """Meta kernel for ``aten._adaptive_avg_pool3d``.

    Replaces the last three dimensions of a 4-D or 5-D input with
    ``output_size``.
    """
    check(
        self.ndim == 4 or self.ndim == 5,
        lambda: f"Expected 4D or 5D tensor, but got {self.shape}",
    )
    leading_dims = self.shape[:-3]
    return self.new_empty(leading_dims + tuple(output_size))
|
|
|
|
|
|
@register_meta(aten._adaptive_avg_pool2d_backward.default)
def meta__adaptive_avg_pool2d_backward(grad_out, self):
    """Meta kernel for ``aten._adaptive_avg_pool2d_backward``.

    Args:
        grad_out: 3-D or 4-D gradient whose non-batch dimensions are all
            non-empty and whose dtype matches ``self``.
        self: the forward input.

    Returns:
        An empty tensor with ``self``'s shape.

    Raises:
        Whatever ``check`` raises when a precondition above is violated.
    """
    ndim = grad_out.ndim
    for i in range(1, ndim):
        check(
            grad_out.size(i) > 0,
            # Adjacent literals instead of a backslash continuation inside
            # the string: the continuation pulled the next source line's
            # indentation into the error message.
            lambda: "adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero "
            f"size for non-batch dimensions, {grad_out.shape} with dimension {i} being empty",
        )
    check(
        ndim == 3 or ndim == 4,
        lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}",
    )
    check(
        self.dtype == grad_out.dtype,
        lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}",
    )
    return self.new_empty(self.shape)
|
|
|
|
|
|
@register_meta(aten.repeat_interleave.Tensor)
def meta_repeat_interleave_Tensor(repeats, output_size=None):
    """Meta kernel for ``aten.repeat_interleave.Tensor``.

    An explicit ``output_size`` is required; without it a ``RuntimeError``
    is raised.
    """
    if output_size is not None:
        return repeats.new_empty(output_size)
    raise RuntimeError("cannot repeat_interleave a meta tensor without output_size")
|
|
|
|
|
|
@register_meta([aten.complex.default, aten.complex.out])
@out_wrapper()
def meta_complex(real, imag):
    """Meta kernel for ``aten.complex``.

    The result has the broadcast shape of ``real`` and ``imag`` and the
    complex dtype corresponding to ``real``'s dtype.
    """
    assert real.dtype.is_floating_point
    assert imag.dtype.is_floating_point
    broadcast_shape = _broadcast_shapes(real.shape, imag.shape)
    complex_dtype = corresponding_complex_dtype(real.dtype)
    return real.new_empty(broadcast_shape, dtype=complex_dtype)
|
|
|
|
|
|
@register_meta(aten.vdot.default)
def vdot(self, other):
    """Meta kernel for ``aten.vdot``.

    Real inputs are forwarded to ``torch.dot``.  Complex inputs with a
    conjugate bit set are rewritten in terms of ``dot``/``vdot`` on the
    conjugated operands.  Otherwise both operands must be 1-D and a 0-dim
    tensor is returned.
    """
    # ``is_complex`` is a method: it must be called. Testing the bound
    # method object is always truthy, which made this branch unreachable.
    if not self.is_complex():
        return torch.dot(self, other)

    if self.is_conj():
        if other.is_conj():
            return torch.vdot(other.conj(), self.conj())
        return torch.dot(self.conj(), other)
    if other.is_conj():
        return torch.dot(self, other.conj()).conj()

    dot_check(self, other)
    return self.new_empty(())
|
|
|
|
|
|
# Leaving this function around because a python implementation
|
|
# of indexing shape inference is useful,
|
|
# but not registering it to the dispatcher because we already
|
|
# get shape inference through structured kernels
|
|
@register_meta(aten.index.Tensor)
def meta_index_Tensor(self, indices):
    """Meta kernel for ``aten.index.Tensor``: compute the result shape of advanced indexing.

    ``indices`` holds one entry per leading dimension of ``self``; an entry is
    either ``None`` (dimension kept) or an index tensor.  int8/bool entries are
    treated as masks and expanded into one coordinate tensor per mask dimension.
    Returns an empty tensor created with ``self.new_empty`` of the inferred shape.
    """
    check(indices, lambda: "at least one index must be provided")

    # aten::index is the internal advanced indexing implementation
    # checkIndexTensorTypes and expandTensors
    expanded: List[Optional[Tensor]] = []
    for i, index in enumerate(indices):
        if index is None:
            expanded.append(index)
            continue

        check(
            index.dtype in [torch.long, torch.int, torch.int8, torch.bool],
            lambda: "tensors used as indices must be long, int, byte or bool tensors",
        )
        if index.dtype not in [torch.int8, torch.bool]:
            expanded.append(index)
            continue

        # Mask index: it consumes ``index.ndim`` dimensions of ``self`` and is
        # replaced by the per-dimension coordinates of its nonzero entries.
        nonzero = index.nonzero()
        k = len(expanded)
        check(
            k + index.ndim <= self.ndim,
            lambda: f"too many indices for tensor of dimension {self.ndim}",
            IndexError,
        )
        for j in range(index.ndim):
            check(
                index.shape[j] == self.shape[k + j],
                lambda: f"The shape of the mask {index.shape} at index {i} "
                f"does not match the shape of the indexed tensor {self.shape} at index {k + j}",
                IndexError,
            )
            expanded.append(nonzero.select(1, j))

    indices = expanded
    check(
        len(indices) <= self.ndim,
        lambda: f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})",
    )

    # expand_outplace
    import torch._refs as refs  # avoid import cycle in mypy

    indices = list(refs._maybe_broadcast(*indices))
    # add missing null tensors so there is one entry per dimension of self
    indices.extend([None] * (self.ndim - len(indices)))

    # hasContiguousSubspace
    #  true if all non-null tensors are adjacent
    # See:
    # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
    # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
    tensor_positions = [pos for pos, entry in enumerate(indices) if entry is not None]
    has_contiguous_subspace = (
        not tensor_positions
        or tensor_positions[-1] - tensor_positions[0] + 1 == len(tensor_positions)
    )

    # transposeToFront
    # This is the logic that causes the newly inserted dimensions to show up
    # at the beginning of the tensor, if they're not contiguous
    if not has_contiguous_subspace:
        null_positions = [pos for pos, entry in enumerate(indices) if entry is None]
        dims = tensor_positions + null_positions
        self = self.permute(dims)
        indices = [indices[pos] for pos in dims]

    # AdvancedIndex::AdvancedIndex
    # Now we can assume the indices have contiguous subspace
    # This is simplified from AdvancedIndex which goes to more effort
    # to put the input and indices in a form so that TensorIterator can
    # take them.  If we write a ref for this, probably that logic should
    # get implemented
    before_shape: List[int] = []
    after_shape: List[int] = []
    replacement_shape: List[int] = []
    for dim, index in enumerate(indices):
        if index is not None:
            # All index tensors were broadcast together, so any of them
            # provides the shape that replaces the indexed dimensions.
            replacement_shape = list(index.shape)
        elif replacement_shape:
            after_shape.append(self.shape[dim])
        else:
            before_shape.append(self.shape[dim])
    return self.new_empty(before_shape + replacement_shape + after_shape)
|
|
|
|
|
|
@register_meta([aten.convolution_backward.default])
def meta_convolution_backward(
    grad_output_,
    input_,
    weight_,
    bias_sizes_opt,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    output_mask,
):
    """Meta kernel for ``aten.convolution_backward``.

    Returns ``(grad_input, grad_weight, grad_bias)``.  Each entry is an empty
    tensor derived from ``grad_output_`` with the size of the input, the weight
    or ``bias_sizes_opt`` respectively, or ``None`` when the matching
    ``output_mask`` flag is false.
    """
    # High level logic taken from slow_conv3d_backward_cpu which should
    # be representative of all convolution_backward impls
    want_input, want_weight, want_bias = output_mask[0], output_mask[1], output_mask[2]

    backend_grad_input = grad_output_.new_empty(input_.size()) if want_input else None
    backend_grad_weight = (
        grad_output_.new_empty(weight_.size()) if want_weight else None
    )
    backend_grad_bias = grad_output_.new_empty(bias_sizes_opt) if want_bias else None

    return (backend_grad_input, backend_grad_weight, backend_grad_bias)
|
|
|
|
|
|
@register_meta([aten.addbmm.default, aten.addbmm.out])
@out_wrapper()
def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1):
    """Meta kernel for ``aten.addbmm``.

    Validates that ``batch1`` and ``batch2`` are 3-D with matching batch and
    contraction sizes, expands ``self`` to the ``(batch1.size(1),
    batch2.size(2))`` result shape and returns an empty tensor of that shape.
    ``beta`` and ``alpha`` are accepted for signature compatibility and unused.
    """
    # Rank checks come first: indexing size(1)/size(2) on a tensor of lower
    # rank would otherwise raise an unrelated out-of-range error before the
    # descriptive messages below could be produced.
    check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
    check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
    check(
        batch1.size(0) == batch2.size(0),
        lambda: f"batch1 and batch2 must have same number of batches, got {batch1.size(0)} and {batch2.size(0)}",
    )
    check(
        batch1.size(2) == batch2.size(1),
        lambda: (
            f"Incompatible matrix sizes for bmm ({batch1.size(1)}x{batch1.size(2)} "
            f"and {batch2.size(1)}x{batch2.size(2)})"
        ),
    )

    dim1 = batch1.size(1)
    dim2 = batch2.size(2)
    self = self.expand((dim1, dim2))
    check(
        self.size(0) == dim1 and self.size(1) == dim2,
        lambda: "self tensor does not match matmul output shape",
    )
    return self.new_empty(self.size())
|
|
|
|
|
|
@register_meta(aten._cdist_forward.default)
def meta_cdist_forward(x1, x2, p, compute_mode):
    """Meta kernel for ``aten._cdist_forward``.

    Checks that both inputs are at least 2-D floating-point tensors with the
    same last dimension, that ``p`` is non-negative and that ``compute_mode``
    is one of ``None``, ``1`` or ``2``.  Returns an empty tensor whose shape is
    the broadcast of the two batch shapes followed by
    ``(x1.size(-2), x2.size(-2))``.
    """
    check(
        x1.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D",
    )
    check(
        x2.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D",
    )
    check(
        x1.size(-1) == x2.size(-1),
        lambda: f"X1 and X2 must have the same number of columns. X1: {x1.size(-1)} X2: {x2.size(-1)}",
    )
    # The two dtype messages were plain strings, so the placeholder text was
    # printed literally; they need the f-prefix to show the offending dtype.
    check(
        utils.is_float_dtype(x1.dtype),
        lambda: f"cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
    )
    check(
        utils.is_float_dtype(x2.dtype),
        lambda: f"cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
    )
    check(p >= 0, lambda: "cdist only supports non-negative p values")
    check(
        compute_mode in (None, 1, 2),
        lambda: f"possible modes: None, 1, 2, but was: {compute_mode}",
    )
    r1 = x1.size(-2)
    r2 = x2.size(-2)
    batch_tensor1 = x1.shape[:-2]
    batch_tensor2 = x2.shape[:-2]
    output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
    output_shape.extend([r1, r2])
    return x1.new_empty(output_shape)
|
|
|
|
|
|
@register_meta(aten._embedding_bag.default)
def meta_embedding_bag(
    weight,
    indices,
    offsets,
    scale_grad_by_freq=False,
    mode=0,
    sparse=False,
    per_sample_weights=None,
    include_last_offset=False,
    padding_idx=-1,
):
    """Meta kernel for ``aten._embedding_bag``.

    Validates dtypes (and ``per_sample_weights`` when given) and returns the
    tuple ``(output, offset2bag, bag_size, max_indices)`` of empty tensors.
    The auxiliary outputs are sized differently depending on whether
    ``device_hint(offsets)`` reports ``"cpu"``, on ``mode`` and, on CPU, on
    whether the sum fast path applies.
    """
    check(
        indices.dtype in (torch.long, torch.int),
        lambda: f"expected indices to be long or int, got {indices.dtype}",
    )
    check(
        offsets.dtype in (torch.long, torch.int),
        lambda: f"expected offsets to be long or int, got {offsets.dtype}",
    )
    check(
        utils.is_float_dtype(weight.dtype),
        lambda: f"expected weight to be floating point type, got {weight.dtype}",
    )

    num_bags = offsets.size(0)
    if include_last_offset:
        # The trailing offset only marks the end of the last bag.
        check(
            num_bags >= 1, lambda: "include_last_offset: numBags should be at least 1"
        )
        num_bags -= 1

    output = weight.new_empty(num_bags, weight.size(1))
    MODE_SUM, MODE_MEAN, MODE_MAX = range(3)

    if per_sample_weights is not None:
        check(
            mode == MODE_SUM,
            lambda: "embedding_bag: per_sample_weights only supported with mode='sum'",
        )
        check(
            per_sample_weights.dtype == weight.dtype,
            lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype",
        )
        check(
            per_sample_weights.ndim == 1,
            lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D",
        )
        check(
            per_sample_weights.numel() == indices.numel(),
            lambda: (
                f"expected per_sample_weights.numel() ({per_sample_weights.numel()} "
                f"to be the same as indices.numel() ({indices.numel()})"
            ),
        )

    on_cpu = device_hint(offsets) == "cpu"
    if not on_cpu:
        offset2bag = indices.new_empty(indices.size(0))
        bag_size = indices.new_empty(offsets.size())
        if mode == MODE_MAX:
            max_indices = indices.new_empty(num_bags, weight.size(1))
        else:
            max_indices = indices.new_empty(0)
        return output, offset2bag, bag_size, max_indices

    # CPU: the sum fast path needs float/half weights, unit inner strides on
    # weight and output, no padding index and, when per-sample weights are
    # given, a unit stride on them as well.
    fast_path_sum = (
        weight.dtype in (torch.float, torch.half)
        and weight.stride(1) == 1
        and output.stride(1) == 1
        and padding_idx < 0
    )
    if per_sample_weights is not None:
        fast_path_sum = fast_path_sum and per_sample_weights.stride(0) == 1

    if mode == MODE_MEAN or mode == MODE_MAX or not fast_path_sum:
        offset2bag = offsets.new_empty(indices.size(0))
    else:
        offset2bag = offsets.new_empty(0)
    bag_size = offsets.new_empty(num_bags)

    # This part of the logic comes from make_max_indices_out in EmbeddingBag.cpp
    max_rows = offsets.shape[0]
    if mode == MODE_MAX:
        if include_last_offset:
            check(
                max_rows >= 1,
                lambda: "include_last_offset: numBags should be at least 1",
            )
            max_rows -= 1
        max_indices = offsets.new_empty(max_rows, weight.shape[1])
    else:
        max_indices = offsets.new_empty(bag_size.size())
    return output, offset2bag, bag_size, max_indices
|
|
|
|
|
|
@register_meta(aten._embedding_bag_forward_only.default)
def meta_embedding_bag_forward_only(weight, indices, offsets, *args):
    """Meta kernel for ``aten._embedding_bag_forward_only``.

    Delegates to ``meta_embedding_bag`` and, when ``device_hint(offsets)`` is
    ``"cpu"``, replaces ``bag_size`` by an empty tensor shaped like ``offsets``.
    """
    results = meta_embedding_bag(weight, indices, offsets, *args)
    output, offset2bag, bag_size, max_indices = results
    on_cpu = device_hint(offsets) == "cpu"
    if on_cpu:
        bag_size = offsets.new_empty(offsets.size())
    return output, offset2bag, bag_size, max_indices
|
|
|
|
|
|
def _get_reduction_dtype(input, dtype, promote_int_to_long=True):
|
|
# if specified, dtype takes precedence
|
|
if dtype:
|
|
return dtype
|
|
|
|
if input.dtype.is_floating_point or input.dtype.is_complex:
|
|
return input.dtype
|
|
elif promote_int_to_long:
|
|
return torch.long
|
|
|
|
return input.dtype
|
|
|
|
|
|
@register_meta([aten.nansum.default, aten.nansum.out])
@out_wrapper()
def meta_nansum(input, dims=None, keepdim=False, *, dtype=None):
    """Meta kernel for ``aten.nansum``.

    Returns an empty tensor with the reduced shape computed by
    ``_compute_reduction_shape`` and the dtype chosen by
    ``_get_reduction_dtype`` (integer promotion to long enabled).
    """
    result_dtype = _get_reduction_dtype(input, dtype, promote_int_to_long=True)
    reduced_dims = utils.reduction_dims(input.shape, dims)
    return input.new_empty(
        _compute_reduction_shape(input, reduced_dims, keepdim), dtype=result_dtype
    )
|
|
|
|
|
|
@register_meta(aten.nanmedian.default)
def meta_nanmedian(input):
    """Meta kernel for ``aten.nanmedian`` reducing over every dimension of ``input``."""
    all_dims = tuple(range(input.dim()))
    reduced_shape = utils.compute_reduction_output_shape(input.shape, all_dims)
    return input.new_empty(reduced_shape)
|
|
|
|
|
|
@register_meta([aten.nanmedian.dim, aten.nanmedian.dim_values])
@out_wrapper("values", "indices")
def meta_nanmedian_dim(input, dim=-1, keepdim=False):
    """Meta kernel for ``aten.nanmedian`` along a single dimension.

    Returns ``(values, indices)``: two empty tensors of the reduced shape, the
    first with the dtype of ``input`` and the second with ``torch.long``.
    """
    dim = utils.reduction_dims(input.shape, (dim,))
    reduced_shape = _compute_reduction_shape(input, dim, keepdim)
    values = input.new_empty(reduced_shape)
    indices = input.new_empty(reduced_shape, dtype=torch.long)
    return (values, indices)
|
|
|
|
|
|
@register_meta(aten.logical_not_.default)
def meta_logical_not_(self):
    """Meta kernel for in-place ``logical_not_``: the input tensor is returned as is."""
    return self
|
|
|
|
|
|
@register_meta(aten.repeat.default)
def meta_repeat(self, repeats):
    """Meta kernel for ``aten.repeat``.

    Requires at least one repeat count per dimension of ``self`` and returns
    an empty tensor whose size is the element-wise product of ``repeats`` with
    the shape of ``self`` left-padded with ones.
    """
    check(
        len(repeats) >= self.dim(),
        lambda: "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
    )
    # Add new leading dimensions to the tensor if the
    # number of target dimensions is larger than the
    # number of source dimensions.
    leading_ones = (1,) * (len(repeats) - self.dim())
    padded_size = leading_ones + tuple(self.shape)
    target_size = [extent * count for extent, count in zip(padded_size, repeats)]
    return self.new_empty(target_size)
|
|
|
|
|
|
@register_meta(aten.zero_.default)
def meta_zero_(self):
    """Meta kernel for in-place ``zero_``: the input tensor is returned as is."""
    return self
|
|
|
|
|
|
@register_meta(
    [
        aten.mul_.Scalar,
        aten.div_.Scalar,
        aten.mul_.Tensor,
        aten.div_.Tensor,
        aten.logical_and_.default,
        aten.logical_or_.default,
        aten.logical_xor_.default,
    ],
)
def meta_binop_inplace(self, other):
    """Shared meta kernel for the in-place binary ops registered above.

    ``other`` is not inspected; the left-hand tensor is returned as is.
    """
    return self
|
|
|
|
|
|
@register_meta(
    [
        aten.add_.Scalar,
        aten.sub_.Scalar,
        aten.add_.Tensor,
        aten.sub_.Tensor,
    ],
)
def meta_binop_inplace_alpha(self, other, alpha=1):
    """Shared meta kernel for in-place ``add_``/``sub_`` overloads taking ``alpha``.

    ``other`` and ``alpha`` are not inspected; the left-hand tensor is
    returned as is.
    """
    return self
|
|
|
|
|
|
@register_meta([aten.round.default, aten.round.decimals])
def meta_round(self, **kwargs):
    """Meta kernel for ``aten.round``.

    Extra keyword arguments are accepted and ignored; the result comes from
    ``_elementwise_meta`` with the default type-promotion kind.
    """
    promotion = ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
    return _elementwise_meta(self, type_promotion=promotion)
|
|
|
|
|
|
@register_meta(aten.zero.default)
def meta_zero(self):
    """Meta kernel for out-of-place ``zero``: a new empty tensor with the shape of ``self``."""
    same_shape = self.shape
    return self.new_empty(same_shape)
|
|
|
|
|
|
@register_meta([aten.fill_.Tensor, aten.fill_.Scalar])
def meta_fill_(self, val):
    """Meta kernel for in-place ``fill_``: ``val`` is ignored and ``self`` is returned."""
    return self
|
|
|
|
|
|
@register_meta([aten.fill.Tensor, aten.fill.Scalar])
def meta_fill(self, val):
    """Meta kernel for out-of-place ``fill``: ``val`` is ignored, result is ``empty_like(self)``."""
    filled = torch.empty_like(self)
    return filled
|
|
|
|
|
|
@register_meta(aten.relu_.default)
def meta_relu_(self):
    """Meta kernel for in-place ``relu_``: the input tensor is returned as is."""
    return self
|
|
|
|
|
|
@register_meta(aten.index_put.default)
def meta_index_put(self, indices, values, accumulate=False):
    """Meta kernel for out-of-place ``index_put``.

    ``indices``, ``values`` and ``accumulate`` are not inspected; the result
    is ``torch.empty_like(self)``.
    """
    result = torch.empty_like(self)
    return result
|
|
|
|
|
|
@register_meta(aten.masked_fill_.Scalar)
def meta_masked_fill_(self, mask, value):
    """Meta kernel for in-place ``masked_fill_``: arguments are ignored, ``self`` is returned."""
    return self
|
|
|
|
|
|
@register_meta(aten.index_put_.default)
def meta_index_put_(self, indices, values, accumulate=False):
    """Meta kernel for in-place ``index_put_``: arguments are ignored, ``self`` is returned."""
    return self
|
|
|
|
|
|
@register_meta(aten.alias.default)
def meta_alias(self):
    """Meta kernel for ``aten.alias``: a view of ``self`` with its own shape."""
    current_shape = self.shape
    return self.view(current_shape)
|
|
|
|
|
|
def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None):
    """Shared shape checking and output allocation for ``bmm`` and ``baddbmm``.

    Args:
        batch1: left operand, must be 3-D.
        batch2: right operand, must be 3-D with matching batch and
            contraction sizes.
        is_bmm: when true, ``self_baddbmm`` is not validated.
        self_baddbmm: optional additive input; when given and ``is_bmm`` is
            false it must be 3-D and already have the output size.

    Returns:
        An empty tensor created from ``batch2`` with size
        ``(batch, rows of batch1, columns of batch2)``.
    """
    check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
    check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")

    batch1_sizes = batch1.size()
    batch2_sizes = batch2.size()

    bs = batch1_sizes[0]
    contraction_size = batch1_sizes[2]
    res_rows = batch1_sizes[1]
    res_cols = batch2_sizes[2]
    output_size = (bs, res_rows, res_cols)

    check(
        batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
        lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}"
        f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].",
    )

    # TODO: handle out

    output = batch2.new_empty(output_size)

    if not is_bmm and self_baddbmm is not None:
        check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor")
        # The message was a plain string naming an undefined ``self``; it must
        # be an f-string reporting the size of ``self_baddbmm``.
        check(
            self_baddbmm.size() == output_size,
            lambda: f"Expected an input tensor shape with shape {output_size} but got shape: {self_baddbmm.size()}",
        )

    return output
|
|
|
|
|
|
@register_meta(aten.bmm.default)
def meta_bmm(self, mat2):
    """Meta kernel for ``aten.bmm``: delegates to ``common_meta_baddbmm_bmm`` in bmm mode."""
    is_bmm = True
    return common_meta_baddbmm_bmm(self, mat2, is_bmm)
|
|
|
|
|
|
def div_rtn(x, y):
    """Return ``x // y``, decremented by one when the remainder is non-zero
    and its sign differs from the sign of ``y``."""
    quotient = x // y
    remainder = x % y
    if remainder != 0:
        # WARNING: explicit bool conversion here is necessary;
        # would be fixed by SymBool
        if bool(remainder < 0) != bool(y < 0):
            quotient -= 1
    return quotient
|
|
|
|
|
|
def pooling_output_shape_pad_lr(
    inputSize, kernelSize, pad_l, pad_r, stride, dilation, ceil_mode
):
    """Compute the pooled extent of one spatial dimension with separate left/right padding.

    In ``ceil_mode`` the numerator is bumped by ``stride - 1`` and the result
    is reduced by one again if the last window would start at or beyond
    ``inputSize + pad_l``.
    """
    rounding_slack = stride - 1 if ceil_mode else 0
    numerator = (
        inputSize + pad_l + pad_r - dilation * (kernelSize - 1) - 1 + rounding_slack
    )
    outputSize = div_rtn(numerator, stride) + 1
    if ceil_mode and (outputSize - 1) * stride >= inputSize + pad_l:
        outputSize -= 1
    return outputSize
|
|
|
|
|
|
def pooling_output_shape(inputSize, kernelSize, pad, stride, dilation, ceil_mode):
    """Validate stride/padding and compute the pooled extent of one spatial dimension.

    The same ``pad`` is applied on both sides; the computation itself is done
    by ``pooling_output_shape_pad_lr``.
    """
    check(stride != 0, lambda: "stride should not be zero")
    check(pad >= 0, lambda: f"pad must be non-negative, but got pad: {pad}")
    half_kernel = kernelSize // 2
    check(
        pad <= half_kernel,
        lambda: f"pad should be at most half of kernel size, but got pad={pad} and kernel_size={kernelSize}",
    )
    pad_left = pad_right = pad
    return pooling_output_shape_pad_lr(
        inputSize, kernelSize, pad_left, pad_right, stride, dilation, ceil_mode
    )
|
|
|
|
|
|
def pool2d_shape_check(
    input,
    kH,
    kW,
    dH,
    dW,
    padH,
    padW,
    dilationH,
    dilationW,
    nInputPlane,
    inputHeight,
    inputWidth,
    outputHeight,
    outputWidth,
    memory_format,
):
    """Validate the arguments of a 2-D pooling operation.

    Checks that kernel, stride and dilation are positive, that ``input`` has
    an accepted rank with non-empty non-batch dimensions for the given
    ``memory_format``, that the padding is at most half the kernel size and
    that the computed output is at least 1x1.  Returns nothing; a failed
    ``check`` reports the offending values.
    """
    ndim = input.dim()
    nOutputPlane = nInputPlane

    # The three messages below and the channels_last one were plain strings,
    # so their placeholders were printed literally; they need the f-prefix.
    check(
        kW > 0 and kH > 0,
        lambda: f"kernel size should be greater than zero, but got kH: {kH}, kW: {kW}",
    )
    check(
        dW > 0 and dH > 0,
        lambda: f"stride should be greater than zero, but got dH: {dH}, dW: {dW}",
    )
    check(
        dilationH > 0 and dilationW > 0,
        lambda: f"dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}",
    )

    valid_dims = input.size(1) != 0 and input.size(2) != 0

    if memory_format == torch.channels_last:
        check(
            ndim == 4 and valid_dims and input.size(3) != 0,
            lambda: "Expected 4D (batch mode) tensor expected for input with channels_last layout"
            f" with optional 0 dim batch size for input, but got: {input.size()}",
        )
    else:
        check(
            (ndim == 3 and input.size(0) != 0 and valid_dims)
            or (ndim == 4 and valid_dims and input.size(3) != 0),
            lambda: f"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: {input.size()}",
        )

    check(
        kW // 2 >= padW and kH // 2 >= padH,
        lambda: "pad should be smaller than or equal to half of kernel size, but got "
        f"padW = {padW}, padH = {padH}, kW = {kW}, kH = {kH}",
    )

    check(
        outputWidth >= 1 and outputHeight >= 1,
        lambda: f"Given input size: ({nInputPlane}x{inputHeight}x{inputWidth}). "
        f"Calculated output size: ({nOutputPlane}x{outputHeight}x{outputWidth}). "
        "Output size is too small",
    )
|
|
|
|
|
|
def max_pool2d_checks_and_compute_shape(
    input, kernel_size, stride, padding, dilation, ceil_mode
):
    """Validate ``max_pool2d`` arguments and compute the pooled spatial size.

    ``kernel_size``, ``padding`` and ``dilation`` hold one or two ints;
    ``stride`` may also be empty, in which case it defaults to the kernel
    size.  Returns ``(nInputPlane, outputHeight, outputWidth)``.
    """
    # Reference: aten/src/ATen/native/DilatedMaxPool2d.cpp
    def _as_pair(name, val):
        # A single value is used for both height and width.
        check(
            len(val) in [1, 2],
            lambda: f"max_pool2d: {name} must either be a single int, or a tuple of two ints",
        )
        if len(val) == 1:
            return val[0], val[0]
        return val[0], val[1]

    kH, kW = _as_pair("kernel_size", kernel_size)

    check(
        len(stride) in [0, 1, 2],
        lambda: "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
    )
    dH, dW = (kH, kW) if len(stride) == 0 else _as_pair("stride", stride)

    padH, padW = _as_pair("padding", padding)
    dilationH, dilationW = _as_pair("dilation", dilation)

    nInputPlane, inputHeight, inputWidth = (
        input.size(-3),
        input.size(-2),
        input.size(-1),
    )

    memory_format = utils.suggest_memory_format(input)
    if memory_format == torch.channels_last:
        check(
            input.dim() == 4,
            lambda: "non-empty 4D (batch mode) tensor expected for input with channels_last layout",
        )
    elif memory_format == torch.contiguous_format:
        check(
            input.dim() in [3, 4],
            lambda: "non-empty 3D or 4D (batch mode) tensor expected for input",
        )
    else:
        check(
            False,
            lambda: "Unsupport memory format. Supports only ChannelsLast, Contiguous",
        )

    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)

    pool2d_shape_check(
        input,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        dilationH,
        dilationW,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        memory_format,
    )

    return nInputPlane, outputHeight, outputWidth
|
|
|
|
|
|
@register_meta(aten.max_pool2d_with_indices_backward.default)
def meta_max_pool2d_with_indices_backward(
    grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices
):
    """Meta kernel for ``aten.max_pool2d_with_indices_backward``.

    Re-runs the forward shape checks on ``self``, requires ``grad_output`` to
    have the dtype of ``self`` and both ``grad_output`` and ``indices`` to
    have the pooled size in their last three dimensions.  Returns an empty
    tensor with the shape, dtype and device of ``self`` in its suggested
    memory format.
    """
    nInputPlane, outputHeight, outputWidth = max_pool2d_checks_and_compute_shape(
        self, kernel_size, stride, padding, dilation, ceil_mode
    )

    # The message was a plain string, so the dtypes were never interpolated.
    check(
        self.dtype == grad_output.dtype,
        lambda: f"expected dtype {self.dtype} for `gradOutput` but got dtype {grad_output.dtype}",
    )

    nOutputPlane = nInputPlane
    ndim = self.ndim

    def _check_dim_size(t):
        # The trailing three dimensions must be (planes, height, width).
        check_dim_size(t, ndim, ndim - 3, nOutputPlane)
        check_dim_size(t, ndim, ndim - 2, outputHeight)
        check_dim_size(t, ndim, ndim - 1, outputWidth)

    _check_dim_size(grad_output)
    _check_dim_size(indices)

    memory_format = utils.suggest_memory_format(self)
    return torch.empty(
        self.shape, dtype=self.dtype, device=self.device, memory_format=memory_format
    )
|
|
|
|
|
|
@register_meta(aten.max_pool2d_with_indices.default)
def meta_max_pool2d_with_indices(
    input, kernel_size, stride=(), padding=(0,), dilation=(1,), ceil_mode=False
):
    """Meta kernel for ``aten.max_pool2d_with_indices``.

    Returns ``(values, indices)``: two empty tensors of the pooled size on the
    device of ``input`` in its suggested memory format, the first with the
    dtype of ``input`` and the second with ``torch.int64``.
    """
    nInputPlane, outputHeight, outputWidth = max_pool2d_checks_and_compute_shape(
        input, kernel_size, stride, padding, dilation, ceil_mode
    )

    memory_format = utils.suggest_memory_format(input)
    size = [nInputPlane, outputHeight, outputWidth]
    if input.dim() != 3:
        # Batched input: prepend the batch dimension.
        nbatch = input.size(-4) if input.dim() == 4 else 1
        size = [nbatch] + size

    shared_kwargs = {"device": input.device, "memory_format": memory_format}
    values = torch.empty(size, dtype=input.dtype, **shared_kwargs)
    indices = torch.empty(size, dtype=torch.int64, **shared_kwargs)
    return (values, indices)
|
|
|
|
|
|
@register_meta(aten.grid_sampler_2d_backward.default)
def grid_sampler_2d_backward_meta(
    grad_output,
    input,
    grid,
    interpolation_mode,
    padding_mode,
    align_corners,
    output_mask,
):
    """Meta kernel for ``aten.grid_sampler_2d_backward``.

    Returns ``(grad_input, grad_grid)``.  ``grad_input`` is a contiguous
    ``zeros_like(input)`` when ``output_mask[0]`` is set and ``None``
    otherwise; ``grad_grid`` is always a contiguous ``empty_like(grid)``.
    """
    contiguous = torch.contiguous_format
    grad_input = (
        torch.zeros_like(input, memory_format=contiguous) if output_mask[0] else None
    )
    grad_grid = torch.empty_like(grid, memory_format=contiguous)
    return (grad_input, grad_grid)
|
|
|
|
|
|
@register_meta([aten.full.default])
def full(size, fill_value, *args, **kwargs):
    """Meta kernel for ``aten.full``: ``fill_value`` is ignored, the rest is passed to ``torch.empty``."""
    result = torch.empty(size, *args, **kwargs)
    return result
|
|
|
|
|
|
@register_meta(
    [
        aten.randint_like.default,
        aten.randint_like.low_dtype,
        aten.randn_like.default,
        aten.rand_like.default,
        aten.full_like.default,
        aten.ones_like.default,
    ]
)
def meta_like(self, *args, **kwargs):
    """Shared meta kernel for the ``*_like`` factories registered above.

    Positional extras are dropped; keyword arguments are forwarded to
    ``aten.empty_like.default``.
    """
    empty_like = aten.empty_like.default
    return empty_like(self, **kwargs)
|
|
|
|
|
|
# zeros_like is special cased to work for sparse
|
|
@register_meta(aten.zeros_like.default)
def zeros_like(
    self, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None
):
    """Meta kernel for ``aten.zeros_like``.

    For any layout other than ``torch.sparse_coo`` this forwards to
    ``aten.empty_like.default``.  For sparse COO it builds an empty sparse
    tensor, resizes it to the size of ``self`` (keeping the sparse/dense
    dimension split of a sparse ``self``) and marks it coalesced.
    """
    if layout != torch.sparse_coo:
        return aten.empty_like.default(
            self,
            dtype=dtype,
            layout=layout,
            device=device,
            pin_memory=pin_memory,
            memory_format=memory_format,
        )

    check(
        memory_format is None,
        lambda: "memory format option is only supported by strided tensors",
    )

    res = torch.empty(
        0,
        dtype=self.dtype if dtype is None else dtype,
        layout=layout,
        device=self.device if device is None else device,
        pin_memory=pin_memory,
    )

    if self.is_sparse:
        sparse_dims, dense_dims = self.sparse_dim(), self.dense_dim()
    else:
        # A strided source is treated as fully sparse.
        sparse_dims, dense_dims = self.dim(), 0
    res.sparse_resize_and_clear_(self.size(), sparse_dims, dense_dims)

    res._coalesced_(True)
    return res
|
|
|
|
|
|
@register_meta(aten.select.int)
def meta_select(self, dim, index):
    """Meta kernel for ``aten.select.int``.

    Wraps negative ``dim`` and ``index``, raises ``IndexError`` through
    ``check`` for 0-dim tensors or out-of-range indices, and returns an
    ``as_strided`` view of ``self`` with dimension ``dim`` removed and the
    storage offset advanced to the selected slice.
    """
    ndim = self.dim()
    check(
        ndim != 0, lambda: "select() cannot be applied to a 0-dim tensor.", IndexError
    )

    if not dim >= 0:
        dim = dim + ndim
    size = self.size(dim)

    check(
        -index <= size and index < size,
        lambda: f"select(): index {index} out of range for tensor of size "
        f"{self.size()} at dimension {dim}",
        IndexError,
    )

    if not index >= 0:
        index = index + size

    remaining_sizes = list(self.size())
    remaining_strides = list(self.stride())

    # Stride of the removed dimension times the index locates the slice.
    slice_offset = self.storage_offset() + index * remaining_strides[dim]
    remaining_sizes.pop(dim)
    remaining_strides.pop(dim)

    return self.as_strided(remaining_sizes, remaining_strides, slice_offset)
|
|
|
|
|
|
@register_meta(aten.select_scatter.default)
def meta_select_scatter(self, src, dim, index):
    """Meta kernel for ``aten.select_scatter``: only ``self`` matters, cloned via ``clone_preserve_strides``."""
    cloned = utils.clone_preserve_strides(self)
    return cloned
|
|
|
|
|
|
@register_meta(aten.slice_scatter.default)
def meta_slice_scatter(self, src, dim=0, start=None, end=None, step=1):
    """Meta kernel for ``aten.slice_scatter``: only ``self`` matters, cloned via ``clone_preserve_strides``."""
    cloned = utils.clone_preserve_strides(self)
    return cloned
|
|
|
|
|
|
# TODO: Deduplicate this with canonicalize_dim
|
|
def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True):
    """Map a possibly negative ``dim`` into ``[0, dim_post_expr)``.

    A non-positive ``dim_post_expr`` is treated as 1, which requires
    ``wrap_scalar``.  Out-of-range values fail an assertion.
    """
    if dim_post_expr <= 0:
        assert wrap_scalar
        dim_post_expr = 1
    lower, upper = -dim_post_expr, dim_post_expr - 1
    assert dim >= lower and dim <= upper, f"dim {dim} out of bounds ({lower}, {upper})"
    return dim + dim_post_expr if dim < 0 else dim
|
|
|
|
|
|
def ensure_nonempty_size(t, dim):
    """Return ``t.shape[dim]``, or 1 when ``t`` is 0-dimensional."""
    if t.dim() == 0:
        return 1
    return t.shape[dim]
|
|
|
|
|
|
# From aten/src/ATen/native/ScatterGatherChecks.h
|
|
def gather_shape_check(self, dim, index):
    """Validate the ``index`` shape for ``gather`` along ``dim``.

    ``self`` and ``index`` must have the same rank (0-dim counts as 1), and in
    every dimension other than ``dim`` the index must not be larger than
    ``self``.
    """
    self_dims = max(self.dim(), 1)
    index_dims = max(index.dim(), 1)
    check(
        self_dims == index_dims,
        lambda: "Index tensor must have the same number of dimensions as input tensor",
    )
    for i in range(self_dims):
        if i == dim:
            continue
        check(
            ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i),
            lambda: f"Size does not match at dimension {i} expected index {index.shape}"
            + f" to be smaller than self {self.shape} apart from dimension {dim}",
        )
|
|
|
|
|
|
@register_meta(aten.gather.default)
def meta_gather(self, dim, index, sparse_grad=False):
    """Meta kernel for ``aten.gather``.

    For a non-empty ``index`` its dtype must be ``torch.long`` and its shape
    must pass ``gather_shape_check``.  The result is an empty tensor with the
    shape of ``index``.
    """
    wrapped_dim = maybe_wrap_dim(dim, self.dim())
    has_elements = index.numel() != 0
    if has_elements:
        check(
            index.dtype == torch.long,
            lambda: f"gather(): Expected dtype int64 for index, but got {index.dtype}",
        )
        gather_shape_check(self, wrapped_dim, index)
    return self.new_empty(index.shape)
|
|
|
|
|
|
# From aten/src/ATen/native/TensorAdvancedIndexing.cpp
|
|
def get_operator_enum(reduce_, use_new_options=False):
    """Translate a ``reduce`` argument into its ``REDUCE_*`` name.

    With ``use_new_options`` the accepted spellings are sum/prod/mean/amax/amin,
    otherwise add/multiply.  Any other value fails a ``check``.
    """
    if use_new_options:
        accepted = (
            ("sum", "REDUCE_ADD"),
            ("prod", "REDUCE_MULTIPLY"),
            ("mean", "REDUCE_MEAN"),
            ("amax", "REDUCE_MAXIMUM"),
            ("amin", "REDUCE_MINIMUM"),
        )
    else:
        accepted = (
            ("add", "REDUCE_ADD"),
            ("multiply", "REDUCE_MULTIPLY"),
        )

    for spelling, enum_name in accepted:
        if reduce_ == spelling:
            return enum_name

    if use_new_options:
        check(
            False,
            lambda: "reduce argument must be either sum, prod, mean, amax or amin.",
        )
    else:
        check(False, lambda: "reduce argument must be either add or multiply.")
    return
|
|
|
|
|
|
# From aten/src/ATen/native/ScatterGatherChecks.h
|
|
def scatter_gather_dtype_check(method_name, self, index, src_opt=None):
    """Check dtypes for scatter/gather style ops.

    A non-empty ``index`` must be ``torch.long``; when ``src_opt`` is given its
    dtype must equal the dtype of ``self``.  ``method_name`` only prefixes the
    error messages.
    """
    index_has_elements = index.numel() != 0
    if index_has_elements:
        check(
            index.dtype == torch.long,
            lambda: f"{method_name}(): Expected dtype int64 for index",
        )

    if src_opt is None:
        return
    check(
        self.dtype == src_opt.dtype,
        lambda: f"{method_name}(): Expected self.dtype to be equal to src.dtype",
    )
|
|
|
|
|
|
def ensure_nonempty_dim(dim):
    """Return ``dim``, but never less than 1."""
    return 1 if dim < 1 else dim
|
|
|
|
|
|
# From aten/src/ATen/native/ScatterGatherChecks.h
|
|
def scatter_shape_check(self, dim, index, src_opt=None):
    """Validate the ``index`` shape for ``scatter`` along ``dim``.

    Nothing is checked for an empty ``index``.  Otherwise ``index`` must have
    the rank of ``self`` (0-dim counts as 1), must not exceed ``self`` in any
    dimension other than ``dim`` and, when ``src_opt`` is given, must not
    exceed ``src_opt`` in any dimension.
    """
    if index.numel() == 0:
        return
    check(
        ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
        lambda: "Index tensor must have the same number of dimensions as self tensor",
    )

    self_dims = ensure_nonempty_dim(self.dim())

    # Check: index.size(d) <= self.size(d) for all d != dim
    is_wrong_shape = any(
        ensure_nonempty_size(index, d) > ensure_nonempty_size(self, d)
        for d in range(self_dims)
        if d != dim
    )

    # Check: index.size(d) <= src.size(d) for all d if src is Tensor
    if src_opt is not None and not is_wrong_shape:
        is_wrong_shape = any(
            ensure_nonempty_size(index, d) > ensure_nonempty_size(src_opt, d)
            for d in range(self_dims)
        )

    if src_opt is None:
        check(
            not is_wrong_shape,
            lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
            + f" apart from dimension {dim}",
        )
        return

    check(
        ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
        lambda: "Index tensor must have the same number of dimensions as self tensor",
    )
    check(
        not is_wrong_shape,
        lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
        + f" apart from dimension {dim} and to be smaller than src {src_opt.shape}",
    )
|
|
|
|
|
|
# From aten/src/ATen/native/TensorAdvancedIndexing.cpp
|
|
def scatter_meta_impl(self, dim, index, src=None, reduce_=None, use_new_options=False):
    """Run the dtype, shape and reduce-argument validation shared by the scatter meta kernels.

    Returns nothing; the helpers it calls fail a ``check`` on invalid input.
    """
    wrapped_dim = maybe_wrap_dim(dim, self.dim())
    scatter_gather_dtype_check("scatter", self, index, src)
    scatter_shape_check(self, wrapped_dim, index, src)
    if reduce_ is None:
        return
    # Check if we have a valid reduce operator.
    get_operator_enum(reduce_, use_new_options)
|
|
|
|
|
|
@register_meta(aten.scatter_add.default)
def meta_scatter_add(self, dim, index, src):
    """Meta kernel for ``aten.scatter_add``: validate, then return a new empty tensor shaped like ``self``."""
    reduce_name = "add"
    scatter_meta_impl(self, dim, index, src, reduce_name)
    return self.new_empty(self.shape)
|
|
|
|
|
|
@register_meta(aten.scatter_add_)
def meta_scatter_add_(self, dim, index, src):
    """Meta kernel for in-place ``scatter_add_``: validate, then return ``self``."""
    reduce_name = "add"
    scatter_meta_impl(self, dim, index, src, reduce_name)
    return self
|
|
|
|
|
|
@register_meta(
    [
        aten.scatter.src,
        aten.scatter.value,
        aten.scatter.reduce,
        aten.scatter.value_reduce,
    ]
)
@out_wrapper()
def meta_scatter(self, dim, index, src_or_value, reduce=None):
    """Meta kernel for the out-of-place ``scatter`` overloads.

    ``src_or_value`` takes part in validation only when it is a tensor.  The
    result is a new empty tensor shaped like ``self``.
    """
    if isinstance(src_or_value, torch.Tensor):
        src = src_or_value
    else:
        src = None
    scatter_meta_impl(self, dim, index, src, reduce)
    return self.new_empty(self.shape)
|
|
|
|
|
|
@register_meta(
    [
        aten.scatter_.src,
        aten.scatter_.value,
        aten.scatter_.reduce,
        aten.scatter_.value_reduce,
    ]
)
def meta_scatter_(self, dim, index, src_or_value, reduce=None):
    """Meta kernel for the in-place ``scatter_`` overloads.

    ``src_or_value`` takes part in validation only when it is a tensor.
    ``self`` is returned.
    """
    if isinstance(src_or_value, torch.Tensor):
        src = src_or_value
    else:
        src = None
    scatter_meta_impl(self, dim, index, src, reduce)
    return self
|
|
|
|
|
|
@register_meta(
    [
        aten._scaled_dot_product_flash_attention,
    ]
)
def meta__scaled_dot_product_flash(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    return_debug_mask: bool = False,
):
    """Meta kernel for ``aten._scaled_dot_product_flash_attention``.

    Only ``dropout_p == 0.0`` is accepted.  Returns a 9-tuple: the attention
    output, the logsumexp buffer (query length rounded up to a multiple of
    16), two ``batch_size + 1`` int32 cumulative-length tensors on the meta
    device, the query and key sequence lengths, two placeholder Philox values
    and a debug mask (size 0 unless ``return_debug_mask``).
    """
    # [Note] SDPA_flash's meta function returns incorrect Philox seed and offset:
    # We have added logic to torch/_dynamo/variables/torch.py
    # We need to check if scaled_dot_product_attention will run the flash attention
    # kernel and if dropout is != 0.0. If that is the case then we want dynamo
    # to graph break. The derivative calculation for _scaled_dot_product_flash_attention
    # does not function correctly with cuda graphs because the full philox state is not
    # captured in the forward's return values. Another reason to graph break is that the
    # meta function returns the wrong outputs for philox seed and offset and these values
    # get baked into the inductor fallback calls to the eager kernels.
    check(
        dropout_p == 0.0,
        lambda: f"Can only trace _scaled_dot_product_flash_attention when dropout is set to 0 but got a dropout_p of {dropout_p}.",
    )
    batch_size, num_heads, max_seqlen_batch_q, head_dim = (
        query.size(0),
        query.size(1),
        query.size(2),
        query.size(3),
    )
    max_seqlen_batch_k = key.size(2)

    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    out_dtype, out_device = query.dtype, query.device

    # Allocate in the packed (batch * seq, heads, dim) layout, then view it
    # back as (batch, heads, seq, dim).
    packed_rows = batch_size * max_seqlen_batch_q
    output = (
        torch.empty(
            (packed_rows, num_heads, head_dim), dtype=out_dtype, device=out_device
        )
        .view(batch_size, max_seqlen_batch_q, num_heads, head_dim)
        .transpose(1, 2)
    )

    max_seqlen_q = math.ceil(max_seqlen_batch_q / 16) * 16
    logsumexp = torch.empty(
        (batch_size, num_heads, max_seqlen_q),
        dtype=torch.float,
        device=out_device,
    )
    cumulative_sequence_length_q = torch.empty(
        batch_size + 1, dtype=torch.int32, device="meta"
    )
    cumulative_sequence_length_k = torch.empty(
        batch_size + 1, dtype=torch.int32, device="meta"
    )

    if not return_debug_mask:
        debug_mask = torch.empty(0, dtype=out_dtype, device=out_device)
    else:
        blocksize_c = 128 if head_dim > 64 else 256
        # NOTE(review): for key lengths above 256 the mask width is derived
        # from the *query* length and is not scaled by blocksize_c - confirm
        # against the kernel before relying on this shape.
        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
        if max_seqlen_batch_k <= 128:
            max_seqlen_k = 128
        elif max_seqlen_batch_k <= 256:
            max_seqlen_k = 256
        debug_mask = torch.empty(
            (batch_size, num_heads, max_seqlen_q, max_seqlen_k),
            dtype=out_dtype,
            device=out_device,
        )

    return (
        output,
        logsumexp,
        cumulative_sequence_length_q,
        cumulative_sequence_length_k,
        max_seqlen_batch_q,
        max_seqlen_batch_k,
        1,  # Philox Seed will not be used, see note at top.
        1,  # Philox Offset will not be used, see note at top.
        debug_mask,
    )
|
|
|
|
|
|
@register_meta([aten._scaled_dot_product_flash_attention_backward])
def meta__scaled_dot_product_flash_backward(
    grad_out: Tensor,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    out: Tensor,
    logsumexp: Tensor,
    cum_seq_q: Tensor,
    cum_seq_k: Tensor,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    philox_seed: int,
    philox_offset: int,
):
    """Allocate uninitialized gradient tensors for query, key and value.

    Only ``query``, ``key``, ``value``, ``max_q`` and ``max_k`` are read; the
    remaining arguments are accepted but not used by this function.

    Returns:
        A ``(grad_q, grad_k, grad_v)`` tuple. Each gradient is allocated flat,
        viewed as ``(batch_size, seq_len, num_heads, head_dim)`` and then has
        dims 1 and 2 transposed.
    """
    batch_size, num_heads, head_dim = query.size(0), query.size(1), query.size(3)

    def _empty_grad(tensor, seq_len):
        # Swap dims 1 and 2, flatten the first two dims into one of length
        # batch_size * seq_len, allocate a like-shaped buffer, then undo
        # the flattening and the transpose on that buffer.
        flat = tensor.transpose(1, 2).reshape(
            batch_size * seq_len, num_heads, head_dim
        )
        grad = torch.empty_like(flat)
        return grad.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

    return (
        _empty_grad(query, max_q),
        _empty_grad(key, max_k),
        _empty_grad(value, max_k),
    )
|
|
|
|
|
|
@register_meta([aten._scaled_dot_product_efficient_attention])
def meta__scaled_dot_product_efficient(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    compute_log_sumexp: bool,
    is_causal: bool = False,
):
    """Allocate the outputs of ``_scaled_dot_product_efficient_attention``.

    Output shapes depend only on ``query`` and on the last dimension of
    ``value``; ``key`` and ``is_causal`` are accepted for signature
    compatibility but do not influence the result.

    Args:
        query: tensor whose dims 0, 1, 2 supply the batch, head and sequence
            sizes of the result.
        key: unused.
        value: tensor whose last dim is the last dim of the result.
        compute_log_sumexp: when false, the logsumexp output has a trailing
            dimension of size 0.
        is_causal: unused.

    Returns:
        ``(res, logsum_exp)``. ``res`` is allocated as
        ``(B, M, num_heads, Kv)`` and returned with dims 1 and 2 transposed;
        ``logsum_exp`` is a float tensor of shape ``(B, num_heads, L)`` where
        ``L`` is ``M`` rounded up to a multiple of 32, or 0.
    """
    query = query.transpose(1, 2)
    value = value.transpose(1, 2)

    B = query.size(0)
    M = query.size(1)
    num_heads = query.size(-2)
    Kv = value.size(-1)

    res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)

    # Round the sequence length up to the next multiple of 32.
    logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
    logsum_exp = torch.empty(
        (B, num_heads, logsumexp_dim),
        dtype=torch.float,
        device=query.device,
    )

    return res.transpose(1, 2), logsum_exp
|
|
|
|
|
|
@register_meta([aten._scaled_dot_product_efficient_attention_backward])
def meta__scaled_dot_product_efficient_backward(
    grad_out: Tensor,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    out: Tensor,
    logsumexp: Tensor,
    is_causal: bool = False,
    chunk_grad_outputs=False,
):
    """Allocate gradients for query, key and value.

    With ``chunk_grad_outputs`` the three gradients are slices of a single
    ``(B, M, 3, nH, K)`` buffer; otherwise each is allocated separately with
    the shape, dtype and device of its corresponding input. In the separate
    case the key/value gradients are zero-filled when ``is_causal`` is set
    and the key sequence is longer than the query sequence.

    Returns:
        ``(grad_q, grad_k, grad_v)``, each with dims 1 and 2 transposed back.
    """
    grad_out = grad_out.transpose(1, 2)
    query, key, value = (t.transpose(1, 2) for t in (query, key, value))

    B, M, nH, K = query.size(0), query.size(1), query.size(2), query.size(3)
    N = key.size(1)

    grad_kv_needs_init = is_causal and N > M

    if chunk_grad_outputs:
        packed = torch.empty(
            (B, M, 3, nH, K), dtype=query.dtype, device=query.device
        )
        grad_q = packed.select(2, 0)
        grad_k = packed.select(2, 1)
        grad_v = packed.select(2, 2)
    else:
        # Key/value gradients are zero-initialized only when required.
        make_kv = torch.zeros if grad_kv_needs_init else torch.empty
        grad_q = torch.empty(query.shape, dtype=query.dtype, device=query.device)
        grad_k = make_kv(key.shape, dtype=key.dtype, device=key.device)
        grad_v = make_kv(value.shape, dtype=value.dtype, device=value.device)

    return tuple(g.transpose(1, 2) for g in (grad_q, grad_k, grad_v))
|
|
|
|
|
|
@register_meta([aten.scatter_reduce.two, aten.scatter_reduce.two_out])
@out_wrapper()
def meta_scatter_reduce_two(self, dim, index, src, reduce, include_self=True):
    """Run the shared scatter checks, then return a new tensor shaped like ``self``.

    ``include_self`` is accepted but not read here.
    """
    scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True)
    result = self.new_empty(self.shape)
    return result
|
|
|
|
|
|
@register_meta(aten.scatter_reduce_.two)
def meta_scatter_reduce__two(self, dim, index, src, reduce, include_self=True):
    """Run the shared scatter checks and return ``self`` unchanged.

    ``include_self`` is accepted but not read here.
    """
    scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True)
    return self
|
|
|
|
|
|
def multiply_integers(vs):
    """Return the product of all values in ``vs`` (1 for an empty iterable)."""
    from functools import reduce
    from operator import mul

    return reduce(mul, vs, 1)
|
|
|
|
|
|
def upsample_common_check(input_size, output_size, num_spatial_dims):
    """Validate upsample sizes and build the full output size.

    Checks that ``output_size`` has ``num_spatial_dims`` entries, that
    ``input_size`` has two more entries than that, and that every spatial
    input size and every output size is positive.

    Returns:
        A tuple made of the first two entries of ``input_size`` followed by
        the entries of ``output_size``.
    """
    check(
        num_spatial_dims == len(output_size),
        lambda: f"It is expected output_size equals to {num_spatial_dims}, but got size {len(output_size)}",
    )

    expected_input_dims = num_spatial_dims + 2  # N, C, ...
    check(
        expected_input_dims == len(input_size),
        lambda: f"It is expected input_size equals to {expected_input_dims}, but got size {len(input_size)}",
    )

    # Output sizes are only inspected once the spatial input sizes pass.
    sizes_positive = all([s > 0 for s in input_size[2:]]) and all(
        [s > 0 for s in output_size]
    )
    check(
        sizes_positive,
        lambda: f"Input and output sizes should be greater than 0, but got "
        f"input size {input_size} and output size {output_size}",
    )

    return (*input_size[:2], *output_size)
|
|
|
|
|
|
@register_meta(aten.upsample_nearest1d.default)
def upsample_nearest1d(input, output_size, scales=None):
    """Return an uninitialized output for 1-D nearest-neighbour upsampling.

    Args:
        input: tensor to upsample; validated by ``upsample_common_check``
            with one spatial dimension.
        output_size: requested spatial output size.
        scales: accepted but not read here.

    Returns:
        A new tensor of the full output size, converted to the memory format
        suggested for ``input``.
    """
    check(
        input.numel() != 0 or multiply_integers(input.size()[1:]),
        # The f-prefix is required so the actual sizes appear in the message.
        lambda: f"Non-empty 3D data tensor expected but got a tensor with sizes {input.size()}",
    )
    full_output_size = upsample_common_check(
        input.size(), output_size, num_spatial_dims=1
    )
    return input.new_empty(full_output_size).to(
        memory_format=utils.suggest_memory_format(input)
    )
|
|
|
|
|
|
@register_meta(aten.upsample_nearest2d.default)
def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None):
    """Return an uninitialized output for 2-D nearest-neighbour upsampling.

    Args:
        input: 4-D tensor to upsample (its shape is unpacked into four dims).
        output_size: requested spatial output size.
        scales_h, scales_w: accepted but not read here.

    Returns:
        A new tensor of the full output size, made contiguous in the memory
        format suggested for ``input``, except that CUDA inputs with fewer
        than 4 channels always use ``torch.contiguous_format``.
    """
    check(
        input.numel() != 0 or multiply_integers(input.size()[1:]),
        # The f-prefix is required so the actual sizes appear in the message.
        lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
    )
    full_output_size = upsample_common_check(
        input.size(), output_size, num_spatial_dims=2
    )
    output = input.new_empty(full_output_size)

    # convert output to correct memory format, if necessary
    memory_format = utils.suggest_memory_format(input)

    # following "heuristic: only use channels_last path when it's faster than the contiguous path"
    _, n_channels, _, _ = input.shape
    if input.device.type == "cuda" and n_channels < 4:
        memory_format = torch.contiguous_format

    output = output.contiguous(memory_format=memory_format)

    return output
|
|
|
|
|
|
@register_meta(aten.upsample_nearest3d.default)
def upsample_nearest3d(input, output_size, scales_d=None, scales_h=None, scales_w=None):
    """Return an uninitialized output for 3-D nearest-neighbour upsampling.

    Args:
        input: tensor to upsample; validated by ``upsample_common_check``
            with three spatial dimensions.
        output_size: requested spatial output size.
        scales_d, scales_h, scales_w: accepted but not read here.

    Returns:
        A new tensor of the full output size, converted to the memory format
        suggested for ``input``.
    """
    check(
        input.numel() != 0 or multiply_integers(input.size()[1:]),
        # The f-prefix is required so the actual sizes appear in the message.
        lambda: f"Non-empty 5D data tensor expected but got a tensor with sizes {input.size()}",
    )
    full_output_size = upsample_common_check(
        input.size(), output_size, num_spatial_dims=3
    )
    return input.new_empty(full_output_size).to(
        memory_format=utils.suggest_memory_format(input)
    )
|
|
|
|
|
|
@register_meta([aten.sort.default, aten.sort.stable])
def meta_sort(self, stable=None, dim=-1, descending=False):
    """Return ``(values, indices)`` placeholders shaped like ``self``.

    ``values`` keeps the dtype of ``self``; ``indices`` is int64. The
    ``stable``, ``dim`` and ``descending`` arguments are not read here.
    """
    values = torch.empty_like(self)
    indices = torch.empty_like(self, dtype=torch.int64)
    return values, indices
|
|
|
|
|
|
def rnn_cell_checkSizes(
    input_gates, hidden_gates, input_bias, hidden_bias, factor, prev_hidden
):
    """Validate the shapes and devices of fused RNN-cell inputs.

    Args:
        input_gates: 2-D tensor; its shape must equal ``hidden_gates``'s.
        hidden_gates: tensor with the same shape as ``input_gates``.
        input_bias: optional 1-D tensor with ``input_gates.size(1)`` elements
            and the same shape as ``hidden_bias``; may be ``None``.
        hidden_bias: optional bias; only inspected when ``input_bias`` is
            given, apart from the device check.
        factor: divisor used to derive the expected ``prev_hidden`` numel.
        prev_hidden: 2-D tensor with
            ``input_gates.size(0) * input_gates.size(1) // factor`` elements.

    Raises:
        Whatever ``check`` raises when a condition does not hold.
    """
    check(input_gates.ndim == 2, lambda: f"{input_gates.ndim} != 2")
    check(
        input_gates.shape == hidden_gates.shape,
        lambda: f"{input_gates.shape} != {hidden_gates.shape}",
    )
    gates_size = input_gates.size(1)
    if input_bias is not None:
        check(input_bias.ndim == 1, lambda: f"{input_bias.ndim} != 1")
        check(
            input_bias.numel() == gates_size,
            lambda: f"{input_bias.numel()} != {gates_size}",
        )
        check(
            input_bias.shape == hidden_bias.shape,
            lambda: f"{input_bias.shape} != {hidden_bias.shape}",
        )
    check(prev_hidden.ndim == 2, lambda: f"{prev_hidden.ndim} != 2")
    expected_prev_hidden_numel = input_gates.size(0) * gates_size // factor
    check(
        prev_hidden.numel() == expected_prev_hidden_numel,
        lambda: f"{prev_hidden.numel()} != {input_gates.size(0)} * {gates_size} // {factor} (aka {expected_prev_hidden_numel})",
    )
    # The biases are optional: skip absent tensors so the device comparison
    # does not fail with AttributeError on ``None.device``.
    present_tensors = [
        x
        for x in (hidden_gates, input_bias, hidden_bias, prev_hidden)
        if x is not None
    ]
    check(
        all(x.device == input_gates.device for x in present_tensors),
        lambda: "expected all inputs to be same device",
    )
|
|
|
|
|
|
@register_meta(aten._thnn_fused_lstm_cell.default)
def _thnn_fused_lstm_cell_meta(
    input_gates, hidden_gates, cx, input_bias=None, hidden_bias=None
):
    """Validate the inputs and allocate the fused LSTM cell outputs.

    Returns:
        ``(hy, cy, workspace)``: ``hy`` and ``cy`` are shaped like ``cx`` and
        ``workspace`` like ``input_gates``, all in contiguous memory format.
    """
    rnn_cell_checkSizes(input_gates, hidden_gates, input_bias, hidden_bias, 4, cx)

    def _contiguous_like(tensor):
        return torch.empty_like(tensor, memory_format=torch.contiguous_format)

    workspace = _contiguous_like(input_gates)
    hy = _contiguous_like(cx)
    cy = _contiguous_like(cx)
    return (hy, cy, workspace)
|
|
|
|
|
|
@register_meta(aten._cudnn_rnn.default)
def _cudnn_rnn(
    input,
    weight,
    weight_stride0,
    weight_buf,
    hx,
    cx,
    mode,
    hidden_size,
    proj_size,
    num_layers,
    batch_first,
    dropout,
    train,
    bidirectional,
    batch_sizes,
    dropout_state,
):
    """Allocate the outputs of ``_cudnn_rnn`` from its shape arguments.

    A non-empty ``batch_sizes`` selects the packed-input layout, where the
    output is 2-D and the mini-batch is ``batch_sizes[0]``; otherwise the
    sequence and batch dimensions are read from ``input`` according to
    ``batch_first``.

    Returns:
        ``(output, hy, cy, reserve, weight_buf)``. ``cy`` is an empty 1-D
        tensor when ``cx`` is ``None``; ``reserve`` is a zero-length uint8
        tensor; ``weight_buf`` is passed through unchanged.
    """
    num_directions = 2 if bidirectional else 1
    out_size = hidden_size if proj_size == 0 else proj_size
    feature_size = out_size * num_directions

    if len(batch_sizes) != 0:
        # Packed input: dim 0 of ``input`` already spans all time steps.
        mini_batch = batch_sizes[0]
        out_shape = [input.shape[0], feature_size]
    elif batch_first:
        mini_batch, seq_length = input.shape[0], input.shape[1]
        out_shape = [mini_batch, seq_length, feature_size]
    else:
        seq_length, mini_batch = input.shape[0], input.shape[1]
        out_shape = [seq_length, mini_batch, feature_size]

    output = input.new_empty(out_shape)

    stacked_layers = num_layers * num_directions
    if cx is None:
        cy = torch.empty(0, device=input.device)
    else:
        cy = cx.new_empty([stacked_layers, mini_batch, hidden_size])

    hy = hx.new_empty([stacked_layers, mini_batch, out_size])

    # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python)
    reserve_shape = 0 if train else 0
    reserve = input.new_empty(reserve_shape, dtype=torch.uint8)

    return output, hy, cy, reserve, weight_buf
|
|
|
|
|
|
@register_meta(aten.mkldnn_rnn_layer.default)
def mkldnn_rnn_layer(
    input,
    w0,
    w1,
    w2,
    w3,
    hx_,
    cx_,
    reverse,
    batch_sizes,
    mode,
    hidden_size,
    num_layers,
    has_biases,
    bidirectional,
    batch_first,
    train,
):
    """Allocate the outputs of ``mkldnn_rnn_layer``.

    Returns:
        ``(output, hy, cy, workspace)``. ``output`` keeps the batch/sequence
        order of ``input`` with ``hidden_size`` as its last dimension; ``hy``
        and ``cy`` mirror the shapes of ``hx_`` and ``cx_`` (or are empty 1-D
        tensors when those are ``None``); ``workspace`` is a zero-length
        uint8 tensor.
    """

    def _state_like(state):
        # A missing state is reported as an empty tensor on the input's device.
        if state is None:
            return torch.empty(0, device=input.device)
        return state.new_empty(state.shape)

    if batch_first:
        mini_batch, seq_length = input.shape[0], input.shape[1]
        out_shape = [mini_batch, seq_length, hidden_size]
    else:
        seq_length, mini_batch = input.shape[0], input.shape[1]
        out_shape = [seq_length, mini_batch, hidden_size]

    output = input.new_empty(out_shape)
    hy = _state_like(hx_)
    cy = _state_like(cx_)
    workspace = torch.empty(0, device=input.device, dtype=torch.uint8)
    return output, hy, cy, workspace
|
|
|
|
|
|
def zero_numel_check_dims(self, dim, fn_name):
    """Check that ``dim`` is a valid reduction dimension for ``self``.

    For a 0-dim tensor ``dim`` must be 0 or -1; otherwise the size of
    ``self`` along ``dim`` must be non-zero. Failures are reported through
    ``check`` with ``IndexError``.
    """
    if self.ndim != 0:
        check(
            self.size(dim) != 0,
            lambda: f"{fn_name}: Expected reduction dim {dim} to have non-zero size.",
            IndexError,
        )
        return

    check(
        dim == 0 or dim == -1,
        lambda: f"{fn_name}: Expected reduction dim -1 or 0 for scalar but got {dim}",
        IndexError,
    )
|
|
|
|
|
|
# From aten/src/ATen/native/ReduceOps.cpp
def check_argmax_argmin(name, self, dim):
    """Validate the reduction dimension for an argmax/argmin style op.

    Without a ``dim`` the input must be non-empty; with one, the wrapped
    dimension is handed to ``zero_numel_check_dims``.
    """
    if dim is None:
        check(
            self.numel() != 0,
            lambda: f"{name}: Expected reduction dim to be specified for input.numel() == 0.",
        )
        return

    wrapped_dim = maybe_wrap_dim(dim, self.dim())
    zero_numel_check_dims(self, wrapped_dim, name)
|
|
|
|
|
|
@register_meta([aten.argmax.default, aten.argmin.default])
def argmax_argmin_meta(self, dim=None, keepdim=False):
    """Return an int64 placeholder with the reduced shape of ``self``.

    NOTE(review): the validation helper is always called with the name
    "argmax", including when this function serves ``aten.argmin``.
    """
    check_argmax_argmin("argmax", self, dim)
    requested_dims = None if dim is None else (dim,)
    dims = utils.reduction_dims(self.shape, requested_dims)
    reduced_shape = _compute_reduction_shape(self, dims, keepdim)
    return self.new_empty(reduced_shape, dtype=torch.int64)
|
|
|
|
|
|
@register_meta(aten.scalar_tensor.default)
def scalar_tensor(s, dtype=None, layout=None, device=None, pin_memory=None):
    """Return an uninitialized 0-dim tensor built with the given options.

    The scalar value ``s`` itself is not read.
    """
    factory_options = {
        "dtype": dtype,
        "layout": layout,
        "device": device,
        "pin_memory": pin_memory,
    }
    return torch.empty((), **factory_options)
|
|
|
|
|
|
@register_meta(aten.topk.default)
def topk_meta(self, k, dim=-1, largest=True, sorted=True):
    """Return ``(values, indices)`` placeholders for ``topk``.

    Both results have the shape of ``self`` with the size along ``dim``
    replaced by ``k`` (unchanged for a 0-dim input); ``indices`` is int64.
    ``largest`` and ``sorted`` are not read here.
    """
    # From aten/src/ATen/native/Sorting.cpp
    dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True)

    # A 0-dim tensor is treated as a slice of length 1.
    slice_size = 1 if self.dim() == 0 else self.size(dim)
    check(
        k >= 0 and k <= slice_size,
        lambda: "selected index k out of range",
    )
    check(k >= 0 and k <= slice_size, lambda: "k not in range for dimension")

    out_shape = list(self.shape)
    if out_shape:
        out_shape[dim] = k

    values = self.new_empty(out_shape)
    indices = self.new_empty(out_shape, dtype=torch.int64)
    return values, indices
|
|
|
|
|
|
# Alias for ``torch.contiguous_format``; passed as ``memory_format`` when
# ``_thnn_fused_lstm_cell_backward_impl`` allocates its outputs.
legacy_contiguous_memory_format = torch.contiguous_format
|
|
|
|
|
|
# From aten/src/ATen/native/cuda/RNN.cu
def checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace):
    """Validate tensor sizes for the fused LSTM cell backward pass.

    The expected size is taken from ``grad_hy`` when it is given, otherwise
    from ``grad_cy``; at least one of the two must not be ``None``.

    Args:
        grad_hy: optional 2-D gradient.
        grad_cy: optional 2-D gradient; must match the expected size.
        cx, cy: tensors that must match the expected size.
        workspace: 2-D tensor holding four times as many elements as the
            expected size.

    Raises:
        Whatever ``check`` raises when a condition does not hold; each
        failure carries a message naming the offending tensor.
    """
    defined_grad = grad_hy if grad_hy is not None else grad_cy
    check(
        defined_grad.dim() == 2,
        lambda: f"expected a 2-D gradient but got {defined_grad.dim()} dimension(s)",
    )
    exp_size = defined_grad.size()
    if grad_hy is not None:
        check(
            grad_hy.size() == exp_size,
            lambda: f"grad_hy size {grad_hy.size()} != expected {exp_size}",
        )
    if grad_cy is not None:
        check(
            grad_cy.size() == exp_size,
            lambda: f"grad_cy size {grad_cy.size()} != expected {exp_size}",
        )
    check(
        cx.size() == exp_size,
        lambda: f"cx size {cx.size()} != expected {exp_size}",
    )
    check(
        cy.size() == exp_size,
        lambda: f"cy size {cy.size()} != expected {exp_size}",
    )
    check(
        workspace.dim() == 2,
        lambda: f"expected a 2-D workspace but got {workspace.dim()} dimension(s)",
    )
    check(
        workspace.numel() == exp_size[0] * exp_size[1] * 4,
        lambda: f"workspace has {workspace.numel()} elements, expected {exp_size[0]} * {exp_size[1]} * 4",
    )
|
|
|
|
|
|
# From aten/src/ATen/native/cuda/RNN.cu
@register_meta(aten._thnn_fused_lstm_cell_backward_impl.default)
def _thnn_fused_lstm_cell_backward_impl(grad_hy, grad_cy, cx, cy, workspace, has_bias):
    """Allocate the gradients of the fused LSTM cell.

    Returns:
        ``(grad_gates, grad_cx, grad_bias)``, or three ``None`` values when
        neither incoming gradient is given. ``grad_bias`` is ``grad_gates``
        summed over dim 0 when ``has_bias`` is set, else ``None``.
    """
    no_incoming_grad = grad_hy is None and grad_cy is None
    if no_incoming_grad:
        return (None,) * 3

    checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace)

    fmt = legacy_contiguous_memory_format
    grad_gates = torch.empty_like(workspace, memory_format=fmt)
    grad_cx = torch.empty_like(cx, memory_format=fmt)

    if has_bias:
        grad_bias = grad_gates.sum(0, keepdim=False)
    else:
        grad_bias = None
    return grad_gates, grad_cx, grad_bias
|
|
|
|
|
|
@register_meta(aten.pixel_shuffle.default)
def meta_pixel_shuffle(self, upscale_factor):
    """Return an uninitialized output for ``pixel_shuffle``.

    The third-from-last dimension is divided by ``upscale_factor ** 2`` and
    the last two dimensions are multiplied by ``upscale_factor``; leading
    dimensions are kept. The result is converted to a memory format chosen
    from the layout and device of ``self``.
    """
    assert (
        len(self.shape) > 2 and self.shape[-3] % (upscale_factor * upscale_factor) == 0
    ), f"Invalid input shape for pixel_shuffle: {self.shape} with upscale_factor = {upscale_factor}"

    def pick_memory_format():
        suggested = torch._prims_common.suggest_memory_format(self)
        if suggested == torch.channels_last:
            # Channels-last inputs keep that layout except on CUDA.
            if device_hint(self) == "cuda":
                return torch.contiguous_format
            return torch.channels_last
        if self.is_contiguous(memory_format=torch.contiguous_format):
            return torch.contiguous_format
        if self.is_contiguous(memory_format=torch.preserve_format):
            return torch.preserve_format
        # No branch matched.
        return None

    *leading_dims, in_channels, in_height, in_width = self.shape
    out_shape = (
        *leading_dims,
        in_channels // (upscale_factor * upscale_factor),
        in_height * upscale_factor,
        in_width * upscale_factor,
    )

    out = self.new_empty(out_shape)
    out = out.to(memory_format=pick_memory_format())  # type: ignore[call-overload]
    return out
|
|
|
|
|
|
@register_meta(aten.mkldnn_rnn_layer_backward.default)
def mkldnn_rnn_layer_backward(
    input,
    weight0,
    weight1,
    weight2,
    weight3,
    hx_,
    cx_tmp,
    output,
    hy_,
    cy_,
    grad_output_r_opt,
    grad_hy_r_opt,
    grad_cy_r_opt,
    reverse,
    mode,
    hidden_size,
    num_layers,
    has_biases,
    train,
    bidirectional,
    batch_sizes,
    batch_first,
    workspace,
):
    """Allocate the gradients returned by ``mkldnn_rnn_layer_backward``.

    Each gradient is an uninitialized tensor shaped like the input it is
    derived from: ``input``, ``weight0``, ``weight1``, ``weight2`` (returned
    twice, as the same tensor object), ``hx_`` and ``cx_tmp``.
    """

    def _empty_like_shape(tensor):
        return tensor.new_empty(tensor.shape)

    diff_x = _empty_like_shape(input)
    diff_hx = _empty_like_shape(hx_)
    diff_cx = _empty_like_shape(cx_tmp)
    diff_w1 = _empty_like_shape(weight0)
    diff_w2 = _empty_like_shape(weight1)
    diff_b = _empty_like_shape(weight2)
    # The single bias gradient fills both bias slots of the result.
    return diff_x, diff_w1, diff_w2, diff_b, diff_b, diff_hx, diff_cx
|
|
|
|
|
|
@register_meta([aten.bucketize.Tensor, aten.bucketize.Tensor_out])
@out_wrapper()
def meta_bucketize(self, boundaries, *, out_int32=False, right=False):
    """Return a contiguous integer placeholder shaped like ``self``.

    The dtype is int32 when ``out_int32`` is set and int64 otherwise;
    ``boundaries`` and ``right`` are not read here.
    """
    index_dtype = torch.int32 if out_int32 else torch.int64
    buckets = torch.empty_like(self, dtype=index_dtype)
    return buckets.contiguous()
|
|
|
|
|
|
@register_meta(aten._upsample_bilinear2d_aa.default)
def meta_upsample_bilinear2d_aa(
    input, output_size, align_corners, scales_h=None, scales_w=None
):
    """Return an uninitialized output for anti-aliased bilinear 2-D upsampling.

    Sizes are validated with ``upsample_common_check`` for two spatial
    dimensions; the result uses the memory format suggested for ``input``.
    ``align_corners``, ``scales_h`` and ``scales_w`` are not read here.
    """
    full_output_size = upsample_common_check(
        input.size(), output_size, num_spatial_dims=2
    )

    input_is_usable = input.numel() != 0 or all(
        [size > 0 for size in input.size()[1:]]
    )
    check(
        input_is_usable,
        lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
    )

    result = input.new_empty(full_output_size)
    return result.to(memory_format=utils.suggest_memory_format(input))
|
|
|
|
|
|
# We must also trigger meta registrations from PrimTorch ref
|
|
# decompositions
|
|
import torch._refs
|
|
import torch._refs.nn.functional
|
|
import torch._refs.special
|
|
|
|
|
|
def activate_meta():
    """Install the collected decomposition functions as meta implementations.

    One function is chosen per operator overload from
    ``global_decomposition_table`` and attached through ``py_impl`` for the
    Meta dispatch key. Overloads that are not CompositeImplicitAutograd, not
    views and not in the skip list are additionally registered with one of
    the ``_meta_lib_*`` libraries, selected by the overload's namespace.

    Raises:
        RuntimeError: if a meta function was registered for an operator that
            has a CompositeImplicitAutograd kernel.
    """
    activate_meta_table = {}

    # For a given op, we pick the most specific decomp function from
    # global_decomp_table in the precedence order of meta > post_autograd > pre_autograd
    for table_kind in ("meta", "post_autograd", "pre_autograd"):
        registry = global_decomposition_table[table_kind]
        for opo in registry:
            # The first table that mentions an overload wins.
            activate_meta_table.setdefault(opo, registry[opo])

    skipped_op_names = {
        "aten::empty_strided",  # causing infinite recursion, test_meta.py
        "aten::clone",  # causing infinite recursion
        "aten::_to_copy",  # causing infinite recursion, test_serialization.py -k test_tensor_subclass_getstate_overwrite  # noqa: B950
        "aten::copy_",  # Exception not raised, test_torch.py -k test_storage_meta_errors_cpu_int64  # noqa: B950
        "aten::constant_pad_nd",  # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_amp_istft_cuda_float32  # noqa: B950
        "aten::rot90",  # requires_grad mismatch! test_ops.py -k test_fake_crossref_backward_amp_rot90_cuda_float32  # noqa: B950
        "aten::as_strided_scatter",  # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_no_amp_as_strided_scatter_cuda_float32  # noqa: B950
    }

    for op_overload, fn in activate_meta_table.items():
        assert isinstance(op_overload, OpOverload)

        op_overload.py_impl(torch._C.DispatchKey.Meta)(fn)

        op_name = op_overload.name()

        if torch._C._dispatch_has_kernel_for_dispatch_key(
            op_name, "CompositeImplicitAutograd"
        ):
            # Internally, we shouldn't be registering meta kernels for any operators that
            # have CompositeImplicitAutograd kernels.
            # Instead, we should be letting those decompositions run, and writing meta kernels
            # only for the base operators.
            if op_overload in global_decomposition_table["meta"]:
                raise RuntimeError(
                    f"{op_overload} is a CompositeImplicitAutograd op, we shouldn't "
                    "register meta function for it. Instead, we should let the decomposition run and write "
                    "meta kernels for the base operators."
                )
            continue

        if op_overload.is_view:
            # Attempting to register a python meta kernel for a view operator.
            # We shouldn't do this, because the output will report as not having aliased storages.
            # All view ops have meta kernels in C++ today, so we should use those instead.
            continue

        if op_name in skipped_op_names:
            continue

        # Pick the library by the namespace found in the overload name.
        if "mkldnn::" in op_name:
            _meta_lib_dont_use_me_use_register_meta_for_mkldnn.impl(op_overload, fn)
        elif "mkl::" in op_name:
            _meta_lib_dont_use_me_use_register_meta_for_mkl.impl(op_overload, fn)
        else:
            _meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn)
|
|
|
|
|
|
@register_meta(aten.all_reduce)
def all_reduce_meta(self, reduceOp, tag, rankset, stride):
    """Return an uninitialized tensor like ``self``; other arguments are not read."""
    reduced = torch.empty_like(self)
    return reduced
|
|
|
|
|
|
@register_meta(aten.wait_tensor)
def wait_tensor_meta(self):
    """Return an uninitialized tensor with the properties of ``self``."""
    placeholder = torch.empty_like(self)
    return placeholder
|
|
|
|
|
|
# Executed when this module is imported, after all the definitions above,
# so the meta functions gathered so far get installed.
activate_meta()
|