Audit for error-prone isinstance int/float and add lint (#87345)

We recently fixed a bug on the symbolic-shapes branch where an
isinstance(x, int) test failed when passed a SymIntNode.  To prevent
this, I've added a lint covering all the codepaths where we may pass
SymInt/SymFloat directly: it rejects direct isinstance int/float
tests and points you to one of the type aliases instead.  The lint
rule explains the options.  I then went through and fixed all the
existing violations.
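
For illustration (not part of this commit's diff), a minimal sketch of
the rewrite the lint pushes you toward; as_stride_list is a hypothetical
helper, and IntLike is one of the aliases this commit adds to
torch._prims_common:

    from torch._prims_common import IntLike

    def as_stride_list(stride, ndim):
        # Before: isinstance(stride, int) silently rejects a SymIntNode stride.
        # After: IntLike is (int, torch.SymIntNode), so symbolic ints pass too.
        if isinstance(stride, IntLike):
            return [stride] * ndim
        return list(stride)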

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87345
Approved by: https://github.com/bdhirsh, https://github.com/albanD
Edward Z. Yang authored on 2022-10-21 05:54:15 -07:00, committed by PyTorch MergeBot
parent 1285542f9b, commit d73d4aa7de
8 changed files with 96 additions and 47 deletions

View File

@@ -420,6 +420,35 @@ command = [
'@{{PATHSFILE}}'
]
[[linter]]
code = 'ERROR_PRONE_ISINSTANCE'
include_patterns = [
'torch/_refs/**/*.py',
'torch/_prims/**/*.py',
'torch/_prims_common/**/*.py',
'torch/_decomp/**/*.py',
'torch/_meta_registrations.py',
]
command = [
'python3',
'tools/linter/adapters/grep_linter.py',
'--pattern=isinstance\([^)]+(int|float)\)',
'--linter-name=ERROR_PRONE_ISINSTANCE',
'--error-name=error prone isinstance',
"""--error-description=\
This line has an isinstance call that directly refers to \
int or float. This is error-prone because you may also \
have wanted to allow SymIntNode or SymFloatNode in your test. \
To suppress this lint, use an appropriate type alias defined \
in torch._prims_common; use IntLike/FloatLike when you would accept \
both regular and symbolic numbers, Dim for ints representing \
dimensions, or IntWithoutSymInt/FloatWithoutSymFloat if you really \
meant to exclude symbolic numbers.
""",
'--',
'@{{PATHSFILE}}'
]
[[linter]]
code = 'PYBIND11_SPECIALIZATION'
include_patterns = [

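A quick illustration (not part of the diff) of what the grep pattern above
flags versus what it accepts, assuming the aliases exported from
torch._prims_common; dim and scale are hypothetical values:

    from torch._prims_common import Dim, FloatLike, IntLike, IntWithoutSymInt

    dim, scale = 1, 2.0  # placeholder values for illustration

    # Flagged: isinstance calls that name int/float directly, e.g.
    #   isinstance(dim, int)  or  isinstance(scale, (int, float))

    # Accepted: the aliases state the intent explicitly.
    assert isinstance(dim, Dim)                     # plain int used as a dimension index
    assert isinstance(scale, (IntLike, FloatLike))  # regular or symbolic numbers
    assert isinstance(dim, IntWithoutSymInt)        # deliberately excludes symbolic ints
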
View File

@@ -181,6 +181,8 @@ class SymFloatNode(object):
@staticmethod
def new_symfloat(obj) -> SymFloatNode: ...
def __ceil__(self) -> SymIntNode: ...
# Defined in torch/csrc/jit/passes/xnnpack_rewrite.h
class MobileOptimizerType:
...

View File

@@ -11,7 +11,7 @@ import torch._prims_common as utils
import torch.nn.functional as F
from torch import Tensor
from torch._decomp import register_decomposition
from torch._prims_common import NumberType, TensorLike, TensorSequenceType
from torch._prims_common import IntLike, NumberType, TensorLike, TensorSequenceType
from torch._prims_common.wrappers import _maybe_resize_out, _safe_copy_out, out_wrapper
from torch.utils._pytree import tree_flatten, tree_map
@@ -1740,7 +1740,7 @@ def adaptive_avg_pool2d(input: Tensor, output_size: Tuple[int, int]):
return torch.mean(vals, dim=(-3, -1))
def maybe_mask(vals, length, range_max, adaptive, dim):
if isinstance(length, int):
if isinstance(length, IntLike):
return vals, length
else:
# zero-out the things we didn't really want to select

View File

@@ -11,6 +11,8 @@ from torch._prims_common import (
corresponding_real_dtype,
elementwise_dtypes,
ELEMENTWISE_TYPE_PROMOTION_KIND,
FloatLike,
IntLike,
)
from torch._prims_common.wrappers import out_wrapper
@@ -361,24 +363,24 @@ def meta_conv(
output_padding: Optional[Union[List[int], int]] = None,
):
ret_shape = []
if isinstance(stride, int):
if isinstance(stride, IntLike):
stride = [stride] * len(dims)
elif len(stride) == 1:
stride = [stride[0]] * len(dims)
if isinstance(padding, int):
if isinstance(padding, IntLike):
padding = [padding] * len(dims)
elif len(padding) == 1:
padding = [padding[0]] * len(dims)
if isinstance(dilation, int):
if isinstance(dilation, IntLike):
dilation = [dilation] * len(dims)
elif len(dilation) == 1:
dilation = [dilation[0]] * len(dims)
output_padding_list: Optional[List[int]] = None
if output_padding:
if isinstance(output_padding, int):
if isinstance(output_padding, IntLike):
output_padding_list = [output_padding] * len(dims)
elif len(output_padding) == 1:
output_padding_list = [output_padding[0]] * len(dims)
@@ -1393,11 +1395,11 @@ def meta_like(self, *args, **kwargs):
# hacky: Please remove after math.ceil works with arange
@register_meta(aten.arange.default)
def arange(end, **kwargs):
if isinstance(end, float):
end = math.ceil(end)
if isinstance(end, FloatLike):
end = math.ceil(end) # type: ignore[arg-type]
def is_integral(x):
return isinstance(x, int) or isinstance(x, bool)
return isinstance(x, IntLike) or isinstance(x, bool)
set_to_integral_dtype = kwargs.get("dtype", None) is None and is_integral(end)
if set_to_integral_dtype:

View File

@@ -16,8 +16,10 @@ from torch._C import _get_default_device
from torch._prims.nvfuser_prims import register_nvprims
from torch._prims_common import (
check,
Dim,
DimsSequenceType,
DimsType,
IntLike,
Number,
NumberType,
RETURN_TYPE,
@@ -929,7 +931,7 @@ bitwise_xor = _make_elementwise_binary_prim(
# div prim performs truncation division on integer inputs
# and true division for floating and complex inputs
def _div_aten(a, b):
is_integral = isinstance(a, (bool, int)) or (
is_integral = isinstance(a, (bool, int, torch.SymIntNode)) or (
isinstance(a, torch.Tensor) and utils.is_integer_dtype(a.dtype)
)
@@ -1198,7 +1200,7 @@ def _broadcast_in_dim_meta(
# (no relative reordering of dims) of integers and
# each dimension must be within the new shape
def _greater_than_reduce(acc, x):
assert isinstance(x, int)
assert isinstance(x, Dim)
assert x > acc
assert x < len(shape)
@@ -2319,7 +2321,7 @@ def _arange_meta(
)
if dtype is not None:
pass
elif all(isinstance(arg, int) for arg in (start, end, step)):
elif all(isinstance(arg, IntLike) for arg in (start, end, step)):
dtype = torch.int64
else:
dtype = torch.get_default_dtype()

View File

@@ -47,7 +47,15 @@ NumberTypeType = Union[Type[bool], Type[int], Type[float], Type[complex]]
# TODO: This needs a lot more type annotations
# NumberType = Union[bool, int, float, complex, torch.SymIntNode, torch.SymFloatNode]
NumberType = Union[bool, int, float, complex]
Number = (bool, int, float, complex, torch.SymIntNode, torch.SymFloatNode)
# I don't call it Integral because numbers.Integral includes bool, but IntLike
# does not
Dim = int
IntLike = (int, torch.SymIntNode)
FloatLike = (float, torch.SymFloatNode)
IntWithoutSymInt = int
FloatWithoutSymFloat = float
DeviceLikeType = Union[str, torch.device]
Tensor = torch.Tensor
@@ -433,8 +441,8 @@ def validate_idx(rank: int, idx: int):
Assumes the index is already canonicalized.
"""
assert isinstance(idx, int)
assert isinstance(rank, int)
assert isinstance(idx, Dim)
assert isinstance(rank, Dim)
assert idx >= 0 and idx < rank or idx == 0
@@ -450,8 +458,8 @@ def validate_exclusive_idx(rank: int, ex_idx: int):
for the given shape.
"""
assert isinstance(ex_idx, int)
assert isinstance(rank, int)
assert isinstance(ex_idx, Dim)
assert isinstance(rank, Dim)
assert ex_idx > 0 and ex_idx <= rank
@@ -500,7 +508,7 @@ def canonicalize_dims(rank: int, indices: int) -> int:
def canonicalize_dims(rank, indices):
if isinstance(indices, int):
if isinstance(indices, Dim):
return canonicalize_dim(rank, indices)
return tuple(canonicalize_dim(rank, x) for x in indices)
@@ -1439,7 +1447,8 @@ def set_correction(
correction = 1
elif correction is None and unbiased is not None:
correction = 0 if unbiased is False else 1
if not isinstance(correction, int):
# NB: we don't actually support symint here, but it's harmless to accept
if not isinstance(correction, IntLike):
raise ValueError("correction argument should be integer")
if correction < 0:
raise ValueError("correction argument should be non-negative")

View File

@@ -16,10 +16,13 @@ import torch._prims_common as utils
from torch._prims_common import (
check,
DeviceLikeType,
Dim,
DimsSequenceType,
DimsType,
dtype_to_type,
ELEMENTWISE_TYPE_PROMOTION_KIND,
FloatLike,
IntLike,
is_weakly_lesser_type,
Number,
NumberType,
@@ -39,6 +42,7 @@ from torch._prims_common.wrappers import (
elementwise_unary_scalar_wrapper,
out_wrapper,
)
from torch.fx.experimental.symbolic_shapes import sym_float, sym_int
# Experimental module containing prototype Python references for existing
# PyTorch operations.
@@ -298,7 +302,7 @@ DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined]
def _broadcast_shapes(*_shapes):
shapes = tuple(
(x,) if isinstance(x, int) else x
(x,) if isinstance(x, IntLike) else x
for x in filter(lambda x: x is not None, _shapes)
)
@@ -1939,8 +1943,8 @@ def _reduction(
"dtype argument and out dtype must match in reduction"
)
if not accepts_dim_tuple:
assert dims is None or isinstance(dims, int)
if isinstance(dims, int):
assert dims is None or isinstance(dims, Dim)
if isinstance(dims, Dim):
dims = (dims,) # type: ignore[assignment]
dims = utils.reduction_dims(a.shape, dims)
if not has_identity:
@@ -1986,7 +1990,7 @@ def all(
keepdim: bool = False,
) -> TensorLikeType:
# Computes nelem
if isinstance(dim, int):
if isinstance(dim, Dim):
dim = (dim,) # type: ignore[assignment]
a_ = _maybe_convert_to_dtype(a, torch.bool)
@@ -2246,7 +2250,7 @@ def mean(
)
if utils.is_integer_dtype(dtype):
raise RuntimeError("result type should be floating point or complex")
if isinstance(dim, int):
if isinstance(dim, Dim):
dim = (dim,) # type: ignore[assignment]
dims = utils.reduction_dims(a.shape, dim) # type: ignore[arg-type]
nelem = 1 if a.ndim == 0 else reduce(operator.mul, (a.shape[i] for i in dims), 1)
@@ -3299,7 +3303,7 @@ def tensor_split(
raise ValueError(msg)
# Case 0 -- indices_or_sections is an integer or a scalar tensor n and a is split along dim into n parts of equal-ish length
if isinstance(indices_or_sections, int) or (
if isinstance(indices_or_sections, IntLike) or (
isinstance(indices_or_sections, TensorLike) and indices_or_sections.ndim == 0
):
sections: int = (
@@ -3365,7 +3369,7 @@ def hsplit(
),
)
dim = 0 if a.ndim == 1 else 1
if isinstance(indices_or_sections, int):
if isinstance(indices_or_sections, IntLike):
split_size = indices_or_sections
check(
(split_size != 0 and a.shape[dim] % split_size == 0),
@@ -3407,7 +3411,7 @@ def vsplit(
+ " dimensions!"
),
)
if isinstance(indices_or_sections, int):
if isinstance(indices_or_sections, IntLike):
split_size = indices_or_sections
check(
(split_size != 0 and a.shape[0] % split_size == 0),
@@ -3538,7 +3542,7 @@ def dsplit(a: TensorLikeType, sections: DimsType) -> TensorSequenceType:
raise RuntimeError(
f"torch.dsplit requires a tensor with at least 3 dimension, but got a tensor with {a.ndim} dimensions!"
)
if isinstance(sections, int) and (sections == 0 or a.shape[2] % sections != 0):
if isinstance(sections, IntLike) and (sections == 0 or a.shape[2] % sections != 0):
raise RuntimeError(
"torch._refs.dsplit attempted to split along dimension 2, "
+ f"but the size of the dimension {a.shape[2]} is not divisible by the split_size {sections}!"
@@ -3983,21 +3987,21 @@ def linspace(
# cast than not, because it allows us to always go into the precise path
# if dtype is integral and not worry about whether start/end are float
if prims.utils.is_integer_dtype(dtype):
if isinstance(start, float):
start = int(start)
if isinstance(end, float):
end = int(end)
if isinstance(start, FloatLike):
start = sym_int(start)
if isinstance(end, FloatLike):
end = sym_int(end)
if py_any(isinstance(arg, complex) for arg in (start, end, steps)):
raise NotImplementedError
assert not isinstance(start, complex) and not isinstance(end, complex) # for mypy
check(
isinstance(steps, int),
isinstance(steps, IntLike),
lambda: "steps must be int, not float",
exc_type=TypeError,
)
assert isinstance(steps, int) # for mypy
assert isinstance(steps, IntLike) # for mypy
check(steps >= 0, lambda: "number of steps must be non-negative")
factory_kwargs = {
@@ -4016,7 +4020,7 @@ def linspace(
if prims.utils.is_integer_dtype(dtype):
# We need to cast to int, so to avoid off-by-one issues
# do the entire computation with ints when we can
assert isinstance(start, int) and isinstance(end, int)
assert isinstance(start, IntLike) and isinstance(end, IntLike)
step_size_x_denom = end - start
eps = 1 if end > start else -1
denom = steps - 1
@@ -4063,10 +4067,10 @@ def logspace(
# NB: NumPy doesn't have this cast
if prims.utils.is_integer_dtype(dtype):
if isinstance(start, float):
start = int(start)
if isinstance(end, float):
end = int(end)
if isinstance(start, FloatLike):
start = sym_int(start)
if isinstance(end, FloatLike):
end = sym_int(end)
assert not isinstance(base, complex) # for mypy
if base < 0:
@@ -4402,10 +4406,10 @@ def uniform(
) -> TensorLikeType:
utils.validate_shape(shape)
assert isinstance(low, (bool, int, float))
assert isinstance(high, (bool, int, float))
low = float(low)
high = float(high)
assert isinstance(low, Number)
assert isinstance(high, Number)
low = sym_float(low)
high = sym_float(high)
assert isinstance(dtype, torch.dtype)
device = utils.canonicalize_device(device)
@@ -4505,10 +4509,10 @@ def norm(
) -> TensorLikeType:
# In these cases we compute the "Frobenius norm"
if (
p == "fro" and (dim is None or isinstance(dim, int) or len(dim) <= 2)
p == "fro" and (dim is None or isinstance(dim, Dim) or len(dim) <= 2)
) or p is None:
p = 2
if isinstance(dim, int):
if isinstance(dim, Dim):
dim = [dim]
if isinstance(p, str):
# Here we either call the nuclear norm, or we call matrix_norm with some arguments

View File

@@ -14,6 +14,7 @@ from torch._prims_common import (
check,
check_fp_or_complex,
check_is_matrix,
Dim,
DimsType,
NumberType,
TensorLikeType,
@@ -69,7 +70,7 @@ def vector_norm(
# Checks
check_fp_or_complex(x.dtype, "linalg.vector_norm")
if isinstance(dim, int):
if isinstance(dim, Dim):
dim = [dim] # type: ignore[assignment]
elif not isinstance(dim, List) and dim is not None:
# refs.amin just accepts List rather than DimType (Tuple)
@@ -142,7 +143,7 @@ def matrix_norm(
check_is_matrix(A, "linalg.matrix_norm")
# dim
dim = utils.canonicalize_dims(A.ndim, dim)
if isinstance(dim, int):
if isinstance(dim, Dim):
dim = (dim,) # type: ignore[assignment]
check(len(dim) == 2, lambda: "linalg.matrix_norm: dim must be a 2-tuple. Got {dim}")
check(
@@ -219,7 +220,7 @@ def norm(
dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
if dim is not None:
if isinstance(dim, int):
if isinstance(dim, Dim):
dim = (dim,) # type: ignore[assignment]
check(
len(dim) in (1, 2),