import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

number = Union[int, float]
# flake8: noqa

###
# There are generated files that depend on this file
# To re-generate, please run from the root of the repo:
# python torchgen/shape_functions/gen_jit_shape_functions.py

# How to test:
# After regenerating files, compile PyTorch.
# Then run: ./build/bin/test_jit --gtest_filter=TestShapeGraphLinting.Basic
# If you have enabled opinfo testing for the op, also run:
# python test/test_ops_jit.py TestJitCPU.test_variant_consistency_jit_[FAILING_OP]_cpu_float32
# to reproduce errors from opinfo tests.

# Example PR: https://github.com/pytorch/pytorch/pull/80860/files
####

import torch


def broadcast(a: List[int], b: List[int]):
    dimsA = len(a)
    dimsB = len(b)
    ndim = max(dimsA, dimsB)
    expandedSizes: List[int] = []

    for i in range(ndim):
        offset = ndim - 1 - i
        dimA = dimsA - 1 - offset
        dimB = dimsB - 1 - offset
        sizeA = a[dimA] if (dimA >= 0) else 1
        sizeB = b[dimB] if (dimB >= 0) else 1

        if sizeA != sizeB and sizeA != 1 and sizeB != 1:
            # TODO: only assertion error is bound in C++ compilation right now
            raise AssertionError(
                f"The size of tensor a ({sizeA}) must match the size of tensor b ({sizeB}) at non-singleton dimension {i}"
            )

        expandedSizes.append(sizeB if sizeA == 1 else sizeA)

    return expandedSizes


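# Worked example (illustration only; not part of the generated shape graphs):
# broadcasting follows standard PyTorch/NumPy semantics, so
# broadcast([8, 1, 6, 1], [7, 1, 5]) evaluates to [8, 7, 6, 5]: trailing
# dimensions are aligned and size-1 dimensions expand.

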
def broadcast_three(a: List[int], b: List[int], c: List[int]):
    return broadcast(broadcast(a, b), c)


def broadcast_one_three(a: List[int], b: Any, c: List[int]):
    return broadcast(a, c)


def adaptive_avg_pool2d(self: List[int], out: List[int]):
    assert len(out) == 2
    assert len(self) == 3 or len(self) == 4
    for i in range(1, len(self)):
        assert self[i] != 0

    shape: List[int] = []
    for i in range(0, len(self) - 2):
        shape.append(self[i])
    for elem in out:
        shape.append(elem)
    return shape


def _copy(self: List[int]):
    out: List[int] = []
    for elem in self:
        out.append(elem)
    return out


def unary(self: List[int]):
    return _copy(self)


def broadcast_inplace(a: List[int], b: List[int]):
    dimsA = len(a)
    dimsB = len(b)
    if dimsB > dimsA:
        raise AssertionError(
            f"The dims of tensor b ({dimsB}) must be less than or equal to the dims of tensor a ({dimsA})"
        )
    for dimA in range(dimsA):
        dimB = dimsB - dimsA + dimA
        sizeA = a[dimA]
        sizeB = b[dimB] if (dimB >= 0) else 1
        if sizeA != sizeB and sizeB != 1:
            # TODO: only assertion error is bound in C++ compilation right now
            raise AssertionError(
                "The size of tensor a {} must match the size of tensor b ("
                "{}) at non-singleton dimension {}".format(sizeA, sizeB, dimA)
            )
    return _copy(a)


def expand(self: List[int], sizes: List[int]):
    assert len(sizes) >= len(self)
    ndim = len(sizes)
    tensor_dim = len(self)
    if ndim == 0:
        return _copy(sizes)
    out: List[int] = []
    for i in range(ndim):
        offset = ndim - 1 - i
        dim = tensor_dim - 1 - offset
        size = self[dim] if dim >= 0 else 1
        targetSize = sizes[i]
        if targetSize == -1:
            assert dim >= 0
            targetSize = size
        if size != targetSize:
            assert size == 1
            size = targetSize
        out.append(size)
    return out


def expand_one_unused(self: List[int], sizes: List[int], inp0: Any):
    return expand(self, sizes)


def infer_size_impl(shape: List[int], numel: int) -> List[int]:
    newsize = 1
    infer_dim: Optional[int] = None
    for dim in range(len(shape)):
        if shape[dim] == -1:
            if infer_dim is not None:
                raise AssertionError("only one dimension can be inferred")
            infer_dim = dim
        elif shape[dim] >= 0:
            newsize *= shape[dim]
        else:
            raise AssertionError("invalid shape dimensions")
    if not (
        numel == newsize
        or (infer_dim is not None and newsize > 0 and numel % newsize == 0)
    ):
        raise AssertionError("invalid shape")
    out = _copy(shape)
    if infer_dim is not None:
        out[infer_dim] = numel // newsize
    return out


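# Worked example (illustration only; not part of the generated shape graphs):
# view-style size inference replaces the single -1 entry with
# numel // product_of_known_sizes, so infer_size_impl([4, -1], 24) == [4, 6].

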
def numel(sizes: List[int]):
    numel = 1
    for elem in sizes:
        numel *= elem
    return numel


def view(self: List[int], sizes: List[int]):
    return infer_size_impl(sizes, numel(self))


def view_one_unused(self: List[int], sizes: List[int], *, implicit: bool = False):
    return view(self, sizes)


def sum_mean_dim(
    self: List[int], opt_dims: Optional[List[int]], keep_dim: bool, dt: Any
):
    out: List[int] = []
    if opt_dims is None or len(opt_dims) == 0:
        dims: List[int] = list(range(len(self)))
    else:
        dims = opt_dims

    for idx in range(len(self)):
        is_mean_dim: bool = False
        for reduce_dim in dims:
            if idx == maybe_wrap_dim(reduce_dim, len(self)):
                is_mean_dim = True
        if is_mean_dim:
            if keep_dim:
                out.append(1)
        else:
            out.append(self[idx])
    return out


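# Worked example (illustration only; not part of the generated shape graphs):
# reducing over dim 1 of a [2, 3, 4] input gives
# sum_mean_dim([2, 3, 4], [1], False, None) == [2, 4], or [2, 1, 4] when
# keep_dim is True.

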
def max_dim(self: List[int], dim: int, keep_dim: bool):
    out = sum_mean_dim(self, [dim], keep_dim, None)
    return out, out


# note: python already rounds down towards negative infinity on integer division, special arithmetic not needed
def div_rtn(x: int, y: int):
    return x // y


def pooling_output_shape_pad_lr(
    inputSize: int,
    kernelSize: int,
    pad_l: int,
    pad_r: int,
    stride: int,
    dilation: int,
    ceil_mode: bool,
):
    outputSize = (
        div_rtn(
            inputSize
            + pad_l
            + pad_r
            - dilation * (kernelSize - 1)
            - 1
            + (stride - 1 if ceil_mode else 0),
            stride,
        )
        + 1
    )
    if ceil_mode:
        if (outputSize - 1) * stride >= inputSize + pad_l:
            outputSize = outputSize - 1
    return outputSize


def pooling_output_shape(
    inputSize: int,
    kernelSize: int,
    pad_l: int,
    stride: int,
    dilation: int,
    ceil_mode: bool,
):
    assert stride != 0, "stride should not be zero"
    return pooling_output_shape_pad_lr(
        inputSize, kernelSize, pad_l, pad_l, stride, dilation, ceil_mode
    )


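# Worked example (illustration only; not part of the generated shape graphs):
# for a spatial size of 32 with kernel 2, padding 0, stride 2, dilation 1 and
# ceil_mode=False, pooling_output_shape(32, 2, 0, 2, 1, False) evaluates to 16,
# matching floor((32 - 2) / 2) + 1.

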
def pool2d_shape_check(
    input: List[int],
    kH: int,
    kW: int,
    dH: int,
    dW: int,
    padH: int,
    padW: int,
    dilationH: int,
    dilationW: int,
    nInputPlane: int,
    inputHeight: int,
    inputWidth: int,
    outputHeight: int,
    outputWidth: int,
):
    ndim = len(input)
    nOutputPlane = nInputPlane

    assert kW > 0 and kH > 0
    assert dW > 0 and dH > 0
    assert dilationH > 0 and dilationW > 0

    valid_dims = input[1] != 0 and input[2] != 0
    assert (
        ndim == 3
        and input[0] != 0
        and valid_dims
        or (ndim == 4 and valid_dims and input[3] != 0)
    )

    assert kW // 2 >= padW and kH // 2 >= padH
    assert outputWidth >= 1 and outputHeight >= 1


def max_pool2d(
    input: List[int],
    kernel_size: List[int],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    ceil_mode: bool,
):
    assert (
        len(kernel_size) == 1 or len(kernel_size) == 2
    ), "max_pool2d: kernel_size must either be a single int, or a tuple of two ints"
    kH = kernel_size[0]
    kW = kH if len(kernel_size) == 1 else kernel_size[1]

    assert (
        len(stride) == 0 or len(stride) == 1 or len(stride) == 2
    ), "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints"
    dH = kH if len(stride) == 0 else stride[0]
    if len(stride) == 0:
        dW = kW
    elif len(stride) == 1:
        dW = dH
    else:
        dW = stride[1]

    assert (
        len(padding) == 1 or len(padding) == 2
    ), "max_pool2d: padding must either be a single int, or a tuple of two ints"
    padH = padding[0]
    padW = padH if len(padding) == 1 else padding[1]

    assert (
        len(dilation) == 1 or len(dilation) == 2
    ), "max_pool2d: dilation must be either a single int, or a tuple of two ints"
    dilationH = dilation[0]
    dilationW = dilationH if len(dilation) == 1 else dilation[1]

    assert len(input) == 3 or len(input) == 4

    nbatch = input[-4] if len(input) == 4 else 1
    nInputPlane = input[-3]
    inputHeight = input[-2]
    inputWidth = input[-1]

    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)

    pool2d_shape_check(
        input,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        dilationH,
        dilationW,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
    )

    if len(input) == 3:
        return [nInputPlane, outputHeight, outputWidth]
    else:
        return [nbatch, nInputPlane, outputHeight, outputWidth]


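# Worked example (illustration only; not part of the generated shape graphs):
# an NCHW input of [1, 3, 32, 32] with kernel_size=[2], default stride,
# padding=[0], dilation=[1] and ceil_mode=False gives
# max_pool2d([1, 3, 32, 32], [2], [], [0], [1], False) == [1, 3, 16, 16].

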
def max_pool2d_with_indices(
    input: List[int],
    kernel_size: List[int],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    ceil_mode: bool,
):
    out = max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
    return (out, out)


def upsample_nearest2d(
    input: List[int],
    output_size: Optional[List[int]],
    scale_factors: Optional[List[float]],
):
    out: List[int] = []
    out.append(input[0])
    out.append(input[1])

    if scale_factors is None and output_size is None:
        assert 0, "Either output_size or scale_factors must be specified"

    if output_size is not None:
        assert (
            scale_factors is None
        ), "Must specify exactly one of output_size and scale_factors"
        assert len(output_size) == 2
        out.append(output_size[0])
        out.append(output_size[1])

    if scale_factors is not None:
        assert (
            output_size is None
        ), "Must specify exactly one of output_size and scale_factors"
        assert len(scale_factors) == 2
        out.append(int(input[2] * scale_factors[0]))
        out.append(int(input[3] * scale_factors[1]))

    return out


def mm(self: List[int], mat2: List[int]):
    assert len(self) == 2, "self must be a matrix"
    assert len(mat2) == 2, "mat2 must be a matrix"

    assert self[1] == mat2[0]
    return [self[0], mat2[1]]


def dot(self: List[int], tensor: List[int]):
    assert len(self) == 1 and len(tensor) == 1
    assert self[0] == tensor[0]
    out: List[int] = []
    return out


def mv(self: List[int], vec: List[int]):
    assert len(self) == 2 and len(vec) == 1
    assert self[1] == vec[0]
    # TODO: return self
    return [self[0]]


def unsqueeze(li: List[int], dim: int):
    dim = maybe_wrap_dim(dim, len(li) + 1)
    out = _copy(li)
    out.insert(dim, 1)
    return out


def squeeze_nodim(li: List[int]):
    out: List[int] = []
    for i in range(len(li)):
        if li[i] != 1:
            out.append(li[i])
    return out


def squeeze(li: List[int], dim: int):
    out: List[int] = []
    wrapped_dim = maybe_wrap_dim(dim, len(li))
    for i in range(len(li)):
        if i == wrapped_dim:
            if li[i] != 1:
                out.append(li[i])
        else:
            out.append(li[i])
    return out


def squeeze_dims(li: List[int], dims: List[int]):
    if len(dims) == 0:
        return li
    wrapped_dims = _copy(dims)
    for i in range(len(dims)):
        wrapped_dims[i] = maybe_wrap_dim(wrapped_dims[i], len(li))
    result: List[int] = []
    for i in range(len(li)):
        if li[i] == 1:
            if i not in wrapped_dims:
                result.append(li[i])
        else:
            result.append(li[i])
    return result


def index_select(self: List[int], dim: int, index: List[int]):
    dim = maybe_wrap_dim(dim, len(self))
    numel = multiply_integers(index)
    assert len(index) <= 1
    assert dim == 0 or dim < len(self)
    result_size: List[int] = []
    for i in range(len(self)):
        if dim == i:
            result_size.append(numel)
        else:
            result_size.append(self[i])
    return result_size


def embedding(
    weight: List[int],
    indices: List[int],
    padding_idx: int = -1,
    scale_grad_by_freq: bool = False,
    sparse: bool = False,
):
    assert len(weight) == 2
    if len(indices) == 1:
        return index_select(weight, 0, indices)
    size = _copy(indices)
    size.append(weight[1])
    return size


def max_int():
    return 9223372036854775807


def slice(
    self: List[int], dim: int, start: Optional[int], end: Optional[int], step: int
):
    ndim = len(self)
    assert ndim != 0
    dim = maybe_wrap_dim(dim, ndim)
    start_val = start if start is not None else 0
    end_val = end if end is not None else max_int()
    assert step > 0
    if start_val == max_int():
        start_val = 0
    if start_val < 0:
        start_val += self[dim]
    if end_val < 0:
        end_val += self[dim]
    if start_val < 0:
        start_val = 0
    elif start_val > self[dim]:
        start_val = self[dim]
    if end_val < start_val:
        end_val = start_val
    elif end_val >= self[dim]:
        end_val = self[dim]
    slice_len = end_val - start_val
    out = _copy(self)
    out[dim] = (slice_len + step - 1) // step
    return out


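# Worked example (illustration only; not part of the generated shape graphs):
# slicing a length-10 dim with start=2, end=None and step=3 keeps indices
# 2, 5 and 8, so slice([10], 0, 2, None, 3) evaluates to [3].

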
def check_cat_no_zero_dim(tensors: List[List[int]]):
    for tensor in tensors:
        assert len(tensor) > 0


def legacy_cat_wrap_dim(dim: int, tensor_sizes: List[List[int]]):
    out_dim: Optional[int] = None
    for size in tensor_sizes:
        if not (len(size) == 1 and size[0] == 0):
            if out_dim is None:
                out_dim = maybe_wrap_dim(dim, len(size))
    if out_dim is None:
        out_dim = dim
    return out_dim


def should_skip(tensor: List[int]):
    return numel(tensor) == 0 and len(tensor) == 1


def check_cat_shape_except_dim(
    first: List[int], second: List[int], dimension: int, index: int
):
    first_dims = len(first)
    second_dims = len(second)
    assert first_dims == second_dims, "Tensors must have same number of dimensions"
    for dim in range(0, first_dims):
        if dim != dimension:
            assert (
                first[dim] == second[dim]
            ), "Sizes of tensors must match except in dimension"


def cat(tensors: List[List[int]], dim: int):
    check_cat_no_zero_dim(tensors)
    dim = legacy_cat_wrap_dim(dim, tensors)
    assert len(tensors) > 0
    not_skipped_tensor: Optional[List[int]] = None
    for tensor in tensors:
        if not should_skip(tensor):
            not_skipped_tensor = tensor
    if not_skipped_tensor is None:
        return [0]

    cat_dim_size = 0

    for i in range(len(tensors)):
        tensor = tensors[i]
        if not should_skip(tensor):
            check_cat_shape_except_dim(not_skipped_tensor, tensor, dim, i)
            cat_dim_size = cat_dim_size + tensor[dim]

    result_size = _copy(not_skipped_tensor)
    result_size[dim] = cat_dim_size
    return result_size


def stack(tensors: List[List[int]], dim: int):
    unsqueezed_tensors: List[List[int]] = []
    for tensor in tensors:
        unsqueezed = unsqueeze(tensor, dim)
        unsqueezed_tensors.append(unsqueezed)
    return cat(unsqueezed_tensors, dim)


def select(self: List[int], dim: int, index: int):
    ndim = len(self)
    assert ndim != 0
    dim = maybe_wrap_dim(dim, ndim)
    size = self[dim]
    assert not (index < -size or index >= size)
    if index < 0:
        index += size
    out: List[int] = []
    for i in range(ndim):
        if i != dim:
            out.append(self[i])
    return out


def matmul(tensor1: List[int], tensor2: List[int]):
    dim_tensor1 = len(tensor1)
    dim_tensor2 = len(tensor2)
    if dim_tensor1 == 1 and dim_tensor2 == 1:
        return dot(tensor1, tensor2)
    elif dim_tensor1 == 2 and dim_tensor2 == 1:
        return mv(tensor1, tensor2)
    elif dim_tensor1 == 1 and dim_tensor2 == 2:
        return squeeze(mm(unsqueeze(tensor1, 0), tensor2), 0)
    elif dim_tensor1 == 2 and dim_tensor2 == 2:
        return mm(tensor1, tensor2)
    elif dim_tensor1 >= 1 and dim_tensor2 >= 1:
        # We are multiplying b1 x n x m1 by b2 x m2 x p (where b1 and b2 can be lists);
        # we track m1 vs m2 separately even though they must match for nicer error messages
        n = tensor1[-2] if dim_tensor1 > 1 else 1
        m1 = tensor1[-1]
        batch_tensor1: List[int] = []
        # TODO: handling of slice
        for i in range(dim_tensor1 - 2):
            batch_tensor1.append(tensor1[i])
        m2 = tensor2[-1] if dim_tensor2 > 1 else 1
        p = tensor2[-1]
        batch_tensor2: List[int] = []
        # TODO: handling of slice
        for i in range(dim_tensor2 - 2):
            batch_tensor2.append(tensor2[i])

        # expand the batch portion (i.e. cut off matrix dimensions and expand rest)
        expand_batch_portion = broadcast(batch_tensor1, batch_tensor2)

        # todo: copy ?
        output_shape = expand_batch_portion
        if dim_tensor1 > 1:
            output_shape.append(n)

        if dim_tensor2 > 1:
            output_shape.append(p)

        return output_shape
    else:
        assert False, "both arguments to matmul need to be at least 1D"


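# Worked example (illustration only; not part of the generated shape graphs):
# batched matmul broadcasts the batch dimensions and contracts the inner ones,
# e.g. matmul([10, 3, 4], [4, 5]) evaluates to [10, 3, 5].

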
def t(self: List[int]):
    assert len(self) <= 2
    self_len = len(self)
    if self_len == 0:
        out: List[int] = []
        return out
    elif self_len == 1:
        return [self[0]]
    else:
        return [self[1], self[0]]


def transpose(self: List[int], dim0: int, dim1: int):
    ndims = len(self)
    dim0 = maybe_wrap_dim(dim0, ndims)
    dim1 = maybe_wrap_dim(dim1, ndims)
    if dim0 == dim1:
        return _copy(self)
    out: List[int] = []
    for i in range(ndims):
        if i == dim0:
            out.append(self[dim1])
        elif i == dim1:
            out.append(self[dim0])
        else:
            out.append(self[i])
    return out


def linear(input: List[int], weight: List[int], bias: Optional[List[int]]):
    out = matmul(input, t(weight))
    if bias is not None:
        assert broadcast(bias, out) == out
    return out


def addmm(self: List[int], mat1: List[int], mat2: List[int], beta: Any, alpha: Any):
    return broadcast(self, mm(mat1, mat2))


def check_non_negative(array: List[int]) -> bool:
    # TODO: look into rewriting with early return and getting loop unrolling to fire
    non_negative = False
    for val in array:
        if val < 0:
            non_negative = True
    return non_negative


def check_shape_forward(
    input: List[int],
    weight_sizes: List[int],
    bias: Optional[List[int]],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    groups: int,
):
    k = len(input)
    weight_dim = len(weight_sizes)

    # TODO: assertions could be expanded with the error messages
    assert not check_non_negative(padding)
    assert not check_non_negative(stride)

    assert weight_dim == k
    assert weight_sizes[0] >= groups
    assert (weight_sizes[0] % groups) == 0
    # only handling not transposed
    assert input[1] == weight_sizes[1] * groups
    assert bias is None or (len(bias) == 1 and bias[0] == weight_sizes[0])

    for i in range(2, k):
        assert (input[i] + 2 * padding[i - 2]) >= (
            dilation[i - 2] * (weight_sizes[i] - 1) + 1
        )

    # this is not handling transposed convolution yet


def conv_output_size(
    input_size: List[int],
    weight_size: List[int],
    bias: Optional[List[int]],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    groups: int,
):
    check_shape_forward(
        input_size, weight_size, bias, stride, padding, dilation, groups
    )

    has_dilation = len(dilation) > 0
    dim = len(input_size)
    output_size: List[int] = []
    input_batch_size_dim = 0
    weight_output_channels_dim = 0
    output_size.append(input_size[input_batch_size_dim])
    output_size.append(weight_size[weight_output_channels_dim])

    for d in range(2, dim):
        dilation_ = dilation[d - 2] if has_dilation else 1
        kernel = dilation_ * (weight_size[d] - 1) + 1
        output_size.append(
            (input_size[d] + (2 * padding[d - 2]) - kernel) // stride[d - 2] + 1
        )
    return output_size


def conv1d(
    input: List[int],
    weight: List[int],
    bias: Optional[List[int]],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    groups: int,
):
    assert len(weight) == 3
    assert len(input) == 3
    return conv_output_size(input, weight, bias, stride, padding, dilation, groups)


def conv2d(
    input: List[int],
    weight: List[int],
    bias: Optional[List[int]],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    groups: int,
):
    assert len(weight) == 4
    assert len(input) == 4
    return conv_output_size(input, weight, bias, stride, padding, dilation, groups)


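# Worked example (illustration only; not part of the generated shape graphs):
# a 3x3 convolution with padding 1 and stride 1,
# conv2d([1, 3, 32, 32], [16, 3, 3, 3], None, [1, 1], [1, 1], [1, 1], 1),
# evaluates to [1, 16, 32, 32] via (32 + 2*1 - 3) // 1 + 1 == 32.

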
def conv_backwards(
    grad_output: List[int],
    input: List[int],
    weight: List[int],
    biases: Optional[List[int]],
):
    # Bias gradient is always generated regardless of whether biases is supplied
    return _copy(input), _copy(weight), [grad_output[1]]


def conv_transpose2d_input(
    input: List[int],
    weight: List[int],
    bias: Optional[List[int]] = None,
    stride: Optional[List[int]] = None,
    padding: Optional[List[int]] = None,
    output_padding: Optional[List[int]] = None,
    groups: int = 1,
    dilation: Optional[List[int]] = None,
) -> List[int]:
    if stride is None:
        stride = [1, 1]
    if padding is None:
        padding = [0, 0]
    if output_padding is None:
        output_padding = [0, 0]
    if dilation is None:
        dilation = [1, 1]
    has_dilation = len(dilation) > 0
    dim = len(input)
    output_size: List[int] = []
    input_batch_size_dim = 0
    weight_output_channels_dim = 1
    output_size.append(input[input_batch_size_dim])
    output_size.append(weight[weight_output_channels_dim] * groups)

    for d in range(2, dim):
        dilation_ = dilation[d - 2] if has_dilation else 1
        kernel = dilation_ * (weight[d] - 1)
        output_size.append(
            (input[d] - 1) * stride[d - 2]
            - 2 * padding[d - 2]
            + kernel
            + output_padding[d - 2]
            + 1
        )
    return output_size


def conv_forwards(
    input: List[int],
    weight: List[int],
    bias: Optional[List[int]],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    transposed: bool,
    output_padding: List[int],
    groups: int,
) -> List[int]:
    has_dilation = len(dilation) > 0
    has_output_padding = len(output_padding) > 0
    dim = len(input)
    output_size: List[int] = []
    input_batch_size_dim = 0
    weight_output_channels_dim = 1 if transposed else 0
    output_size.append(input[input_batch_size_dim])
    if transposed:
        output_size.append(weight[weight_output_channels_dim] * groups)
    else:
        output_size.append(weight[weight_output_channels_dim])

    for d in range(2, dim):
        dilation_ = dilation[d - 2] if has_dilation else 1
        output_padding_ = output_padding[d - 2] if has_output_padding else 0
        if transposed:
            kernel = dilation_ * (weight[d] - 1)
            output_size.append(
                (input[d] - 1) * stride[d - 2]
                - 2 * padding[d - 2]
                + kernel
                + output_padding_
                + 1
            )
        else:
            kernel = dilation_ * (weight[d] - 1) + 1
            output_size.append(
                (input[d] + (2 * padding[d - 2]) - kernel) // stride[d - 2] + 1
            )
    return output_size


def _conv_forwards(
    input: List[int],
    weight: List[int],
    bias: Optional[List[int]],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    transposed: bool,
    output_padding: List[int],
    groups: int,
    benchmark: bool,
    deterministic: bool,
    cudnn_enabled: bool,
    allow_tf32: bool,
) -> List[int]:
    return conv_forwards(
        input,
        weight,
        bias,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
    )


def batch_norm(
    input: List[int],
    weight: Optional[List[int]],
    bias: Optional[List[int]],
    running_mean: Optional[List[int]],
    running_var: Optional[List[int]],
    training: bool,
    momentum: float,
    eps: float,
    cudnn_enabled: bool,
):
    out: List[int] = []
    for elem in input:
        out.append(elem)
    return out


def conv3d(
    input: List[int],
    weight: List[int],
    bias: Optional[List[int]],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    groups: int,
):
    assert len(weight) == 5
    assert len(input) == 5
    return conv_output_size(input, weight, bias, stride, padding, dilation, groups)


def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True):
    if dim_post_expr <= 0:
        assert wrap_scalar
        dim_post_expr = 1
    min = -dim_post_expr
    max = dim_post_expr - 1
    assert not (dim < min or dim > max)
    if dim < 0:
        dim += dim_post_expr
    return dim


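# Worked example (illustration only; not part of the generated shape graphs):
# negative dims wrap the usual way, e.g. maybe_wrap_dim(-1, 4) == 3, and
# scalars (dim_post_expr == 0) are treated as 1-dimensional, so
# maybe_wrap_dim(0, 0) == 0.

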
def zero_dim_tensor(input: Any):
    out: List[int] = []
    return out


def multiply_integers(li: List[int]):
    out = 1
    for elem in li:
        out = out * elem
    return out


def arange_end(end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any):
    assert end >= 0
    return [int(math.ceil(end))]


def arange_start(
    start: number, end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any
):
    assert end >= 0
    assert end >= start
    return [int(math.ceil(end - start))]


def arange_start_step(
    start: number, end: number, step: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any
):
    assert step != 0
    if step < 0:
        assert start >= end
    else:
        assert end >= start
    return [int(math.ceil((end - start) / step))]


def permute(input: List[int], dims: List[int]):
    assert len(input) == len(dims)
    ndim = len(dims)
    seen_dims: List[int] = []
    newSizes: List[int] = []
    for i in range(ndim):
        dim = maybe_wrap_dim(dims[i], ndim)
        seen_dims.append(dim)
        newSizes.append(input[dim])
    for i in range(1, ndim):
        for j in range(i):
            assert seen_dims[i] != seen_dims[j]
    return newSizes


def movedim(self: List[int], source: List[int], destination: List[int]) -> List[int]:
    self_dim = len(self)
    if self_dim <= 1:
        return self
    normalized_src: List[int] = []
    normalized_dst: List[int] = []
    for i in range(len(source)):
        normalized_src.append(maybe_wrap_dim(source[i], self_dim))
        normalized_dst.append(maybe_wrap_dim(destination[i], self_dim))
    order = [-1 for i in range(self_dim)]
    src_dims = [i for i in range(self_dim)]
    dst_dims = [i for i in range(self_dim)]

    for i in range(len(source)):
        order[normalized_dst[i]] = normalized_src[i]
        src_dims[normalized_src[i]] = -1
        dst_dims[normalized_dst[i]] = -1

    source_dims: List[int] = []
    destination_dims: List[int] = []
    for ele in src_dims:
        if ele != -1:
            source_dims.append(ele)
    for ele in dst_dims:
        if ele != -1:
            destination_dims.append(ele)

    rest_dim = self_dim - len(source)
    for i in range(rest_dim):
        order[destination_dims[i]] = source_dims[i]
    return permute(self, order)


def flatten(input: List[int], start_dim: int, end_dim: int):
    start_dim = maybe_wrap_dim(start_dim, len(input))
    end_dim = maybe_wrap_dim(end_dim, len(input))
    assert start_dim <= end_dim
    if len(input) == 0:
        return [1]
    if start_dim == end_dim:
        # TODO: return self
        out: List[int] = []
        for elem in input:
            out.append(elem)
        return out
    slice_numel = 1
    for i in range(start_dim, end_dim + 1):
        slice_numel *= input[i]
    # TODO: use slicing when slice optimization has landed
    # slice_numel = multiply_integers(input[start_dim:end_dim - start_dim + 1])
    shape: List[int] = []
    for i in range(start_dim):
        shape.append(input[i])
    shape.append(slice_numel)
    for i in range(end_dim + 1, len(input)):
        shape.append(input[i])
    return shape


def nonzero_lower_bound(input: List[int]):
    return [0, len(input)]


def nonzero_upper_bound(input: List[int]):
    return [numel(input), len(input)]


def _reduce_along_dim(self: List[int], dim: int, keepdim: bool):
    dim = maybe_wrap_dim(dim, len(self))
    out: List[int] = []
    for i, self_dim in enumerate(self):
        if i == dim:
            if keepdim:
                out.append(1)
        else:
            out.append(self_dim)
    return out


def argmax(
    self: List[int], dim: Optional[int] = None, keepdim: bool = False
) -> List[int]:
    if dim is None:
        return []
    return _reduce_along_dim(self, dim, keepdim)


def bmm(self: List[int], mat2: List[int]) -> List[int]:
    assert len(self) == 3, "bmm only supports 3D tensors"
    assert len(mat2) == 3, "bmm only supports 3D tensors"
    assert self[0] == mat2[0], "mismatching batch dimension"
    assert self[2] == mat2[1], "mismatching contracting dimension"
    return [self[0], self[1], mat2[2]]


def _shape_as_tensor(self: List[int]) -> List[int]:
    return [len(self)]


def topk(self: List[int], k: int, dim: int = -1) -> Tuple[List[int], List[int]]:
    if len(self) == 0:
        result: List[int] = []
    else:
        assert (
            k <= self[dim]
        ), f"k ({k}) is too big for dimension {dim} of size {self[dim]}"
        result = _copy(self)
        result[dim] = k
    return result, result


def nll_loss_forward(
    self: List[int], target: List[int], weight: Optional[List[int]], reduction: int
) -> Tuple[List[int], List[int]]:
    # This is taken shamelessly from the meta function in LossNLL.cpp
    self_dim = len(self)
    target_dim = len(target)
    assert 0 < self_dim <= 2
    assert target_dim <= 1
    no_batch_dim = self_dim == 1 and target_dim == 0
    assert no_batch_dim or (self[0] == target[0])
    n_classes = self[-1]
    scalar_shape: List[int] = []
    assert weight is None or (len(weight) == 1 and weight[0] == n_classes)
    if reduction == 0 and self_dim == 2:
        reduction_shape = [self[0]]
    else:
        reduction_shape = scalar_shape
    return reduction_shape, scalar_shape


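# Worked example (illustration only; not part of the generated shape graphs):
# for logits of shape [8, 10] and targets of shape [8],
# nll_loss_forward([8, 10], [8], None, 1) evaluates to ([], []) under mean
# reduction, while reduction=0 gives ([8], []).

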
def native_layer_norm(
    input: List[int], normalized_shape: List[int]
) -> Tuple[List[int], List[int], List[int]]:
    reduction_shape: List[int] = []
    num_unreduced_dimensions = len(input) - len(normalized_shape)
    assert num_unreduced_dimensions >= 0
    for i in range(num_unreduced_dimensions):
        reduction_shape.append(input[i])
    for i in range(num_unreduced_dimensions, len(input)):
        reduction_shape.append(1)
    return _copy(input), reduction_shape, reduction_shape


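# Worked example (illustration only; not part of the generated shape graphs):
# native_layer_norm([2, 5, 10], [10]) evaluates to
# ([2, 5, 10], [2, 5, 1], [2, 5, 1]), i.e. the normalized output plus the
# mean/rstd shapes with the normalized dims collapsed to 1.

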
def native_batch_norm(
    input: List[int],
    weight: Optional[List[int]],
    bias: Optional[List[int]],
    running_mean: Optional[List[int]],
    running_var: Optional[List[int]],
    training: bool,
) -> Tuple[List[int], List[int], List[int]]:
    if training:
        _size = [input[1]]
    else:
        _size = [0]
    return _copy(input), _size, _size


def cross_entropy_loss(
    self: List[int],
    target: List[int],
    weight: Optional[List[int]] = None,
    reduction: int = 1,
    ignore_index: int = -100,
    label_smoothing: float = 0.0,
) -> List[int]:
    result_shape = nll_loss_forward(self, target, weight, reduction)[0]
    return result_shape


"""
Currently deferring enabling this, as part of the proposal to suspend
adding ops.
There are currently cases in the test suite where this is being called
in the SSA opinfo tests with unexpected values (e.g. a list of two ints, see the first
opinfo test). The behavior of index is significantly dependent on the inputs.

This could be an error in how we are matching up shape functions, or this
function may simply need to implement everything.

def index_Tensor(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
    assert len(indices) <= len(self), "More indices than dimensions to index"
    broadcasted_shape: List[int] = []
    for index_tensor_shape in indices:
        if index_tensor_shape is not None:
            broadcasted_shape = broadcast(broadcasted_shape, index_tensor_shape)
    return broadcasted_shape
"""

ScriptFn = torch._C.ScriptFunction
shape_compute_graph_mapping: Dict[str, ScriptFn] = {}
bounded_compute_graph_mapping: Dict[str, Tuple[ScriptFn, ScriptFn]] = {}
script_func_map: Dict[Callable, ScriptFn] = {}


def process_func(func: Callable):
    if func not in script_func_map:
        scripted_func = torch.jit.script(func)

        torch._C._jit_pass_inline(scripted_func.graph)

        for _ in range(2):
            torch._C._jit_pass_peephole(scripted_func.graph)
            torch._C._jit_pass_constant_propagation(scripted_func.graph)

        script_func_map[func] = scripted_func
    return script_func_map[func]


def add_shape_compute_mapping(operator_schema: str, func: Callable):
    global shape_compute_graph_mapping

    shape_compute_graph_mapping[operator_schema] = process_func(func)


def add_bounded_compute_mapping(
    operator_schema: str, lower_bound_func: Callable, upper_bound_func: Callable
):
    # Adds a shape compute function for both upper and lower bounds
    fns = (process_func(lower_bound_func), process_func(upper_bound_func))
    bounded_compute_graph_mapping[operator_schema] = fns


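# Illustrative note (hypothetical example, not an actual registration below):
# registering a shape function for a custom op would follow the same pattern
# as the calls that follow, e.g.
#   add_shape_compute_mapping("myns::my_op(Tensor self) -> Tensor", unary)
# where the schema string must match the op's registered signature exactly.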
add_shape_compute_mapping(
    "aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)",
    unary,
)
add_shape_compute_mapping(
    "aten::rsub.Tensor(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", unary
)
add_shape_compute_mapping(
    "aten::dropout(Tensor input, float p, bool train) -> Tensor", unary
)
add_shape_compute_mapping(
    "aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor",
    adaptive_avg_pool2d,
)
add_shape_compute_mapping(
    "prim::NumToTensor.Scalar(Scalar a) -> Tensor", zero_dim_tensor
)
add_shape_compute_mapping("prim::NumToTensor.bool(bool a) -> Tensor", zero_dim_tensor)
add_shape_compute_mapping(
    "aten::zeros(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)",
    unary,
)
add_shape_compute_mapping(
    "aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))",
    unary,
)
add_shape_compute_mapping(
    "aten::arange(Scalar end, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)",
    arange_end,
)
add_shape_compute_mapping(
    "aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor",
    arange_start,
)
add_shape_compute_mapping(
    "aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor",
    arange_start_step,
)
add_shape_compute_mapping("aten::squeeze(Tensor(a) self) -> Tensor(a)", squeeze_nodim)
add_shape_compute_mapping(
    "aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)", squeeze
)
add_shape_compute_mapping(
    "aten::squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)", squeeze_dims
)
add_shape_compute_mapping(
    "aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)", unsqueeze
)
add_shape_compute_mapping(
    "aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)",
    slice,
)
add_shape_compute_mapping(
    "aten::select.int(Tensor(a) self, int dim, int index) -> Tensor(a)", select
)
add_shape_compute_mapping(
    "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor", index_select
)
add_shape_compute_mapping(
    "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, "
    "float eps=1e-05, bool cudnn_enable=True) -> Tensor",
    unary,
)
add_shape_compute_mapping(
    "aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", unary
)
add_shape_compute_mapping(
    "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor",
    unary,
)
add_shape_compute_mapping(
    "aten::embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)",
    unary,
)
add_shape_compute_mapping(
    "aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor",
    embedding,
)
add_shape_compute_mapping("aten::mm(Tensor self, Tensor mat2) -> Tensor", mm)
add_shape_compute_mapping("aten::dot(Tensor self, Tensor tensor) -> Tensor", dot)
add_shape_compute_mapping("aten::mv(Tensor self, Tensor vec) -> Tensor", mv)
add_shape_compute_mapping("aten::matmul(Tensor self, Tensor other) -> Tensor", matmul)
add_shape_compute_mapping(
    "aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor", linear
)
add_shape_compute_mapping(
    "aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
    max_pool2d,
)
add_shape_compute_mapping(
    "aten::max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)",
    max_pool2d_with_indices,
)
add_shape_compute_mapping("aten::t(Tensor(a) self) -> Tensor(a)", t)
add_shape_compute_mapping(
    "aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)", transpose
)
add_shape_compute_mapping(
    "aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> Tensor",
    conv1d,
)
add_shape_compute_mapping(
    "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
    conv2d,
)
add_shape_compute_mapping(
    "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
    batch_norm,
)
add_shape_compute_mapping(
    "aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor",
    conv3d,
)
add_shape_compute_mapping(
    "aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, int[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)",
    conv_backwards,
)
add_shape_compute_mapping(
    "aten::convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor",
    conv_forwards,
)
add_shape_compute_mapping(
    "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor",
    _conv_forwards,
)
add_shape_compute_mapping(
    "aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int groups=1, int[2] dilation=1) -> Tensor",
    conv_transpose2d_input,
)
add_shape_compute_mapping(
    "aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)",
    flatten,
)
add_shape_compute_mapping("aten::cat(Tensor[] tensors, int dim=0) -> Tensor", cat)
add_shape_compute_mapping("aten::stack(Tensor[] tensors, int dim=0) -> Tensor", stack)
add_shape_compute_mapping(
    "aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", permute
)
add_shape_compute_mapping(
    "aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)",
    movedim,
)
add_shape_compute_mapping("aten::view(Tensor(a) self, int[] size) -> Tensor(a)", view)
add_shape_compute_mapping(
    "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", expand
)
add_shape_compute_mapping(
    "aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)",
    expand_one_unused,
)
add_shape_compute_mapping(
    "aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
    sum_mean_dim,
)
add_shape_compute_mapping(
    "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
    sum_mean_dim,
)
add_shape_compute_mapping(
    "aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)",
    max_dim,
)
add_shape_compute_mapping(
    "aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor
)
add_shape_compute_mapping(
    "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor
)
add_shape_compute_mapping(
    "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor",
    addmm,
)
add_shape_compute_mapping(
    "aten::upsample_nearest2d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)",
    upsample_nearest2d,
)
add_shape_compute_mapping(
    "aten::quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor",
    unary,
)
add_shape_compute_mapping(
    "aten::quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype) -> Tensor",
    unary,
)
add_shape_compute_mapping("aten::dequantize(Tensor self) -> Tensor", unary)
add_shape_compute_mapping(
    "quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc",
    broadcast,
)
add_shape_compute_mapping(
    "aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor", argmax
)
add_shape_compute_mapping("aten::bmm(Tensor self, Tensor mat2) -> Tensor", bmm)
add_shape_compute_mapping(
    "aten::_shape_as_tensor(Tensor self) -> Tensor", _shape_as_tensor
)
add_shape_compute_mapping(
    "aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)",
    topk,
)
add_shape_compute_mapping(
    "aten::nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight)",
    nll_loss_forward,
)
add_shape_compute_mapping(
    "aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)",
    native_layer_norm,
)
add_shape_compute_mapping(
    "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
    native_batch_norm,
)
add_shape_compute_mapping(
    "aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
    native_batch_norm,
)
add_shape_compute_mapping(
    "aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
    native_batch_norm,
)
add_shape_compute_mapping(
    "aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor",
    cross_entropy_loss,
)
# add_shape_compute_mapping("aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor", index_Tensor)

# TODO: migrate over all of symbolic_shape_registry_util.cpp
# These are duplicated here so that the functions will be serialized
add_shape_compute_mapping(
    "aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor",
    broadcast_three,
)
add_shape_compute_mapping(
    "aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor",
    broadcast_one_three,
)
add_shape_compute_mapping(
    "aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)",
    broadcast_inplace,
)

# quantized_conv_prepack TODO

# Shape Compute Fn with upper and lower bounds
add_bounded_compute_mapping(
    "aten::nonzero(Tensor self) -> (Tensor)", nonzero_lower_bound, nonzero_upper_bound
)