Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-02 14:34:54 +08:00)
Audit for error prone isinstance int/float and add lint (#87345)
We recently fixed a bug on the symbolic-shapes branch where an isinstance(x, int) test failed when it was passed a SymIntNode. To prevent this, I've added a lint that rejects direct isinstance int/float tests on all the codepaths where we may pass a SymInt/SymFloat directly, and requires one of the type aliases instead. The lint rule explains the options. I then go through and fix all existing violations.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87345
Approved by: https://github.com/bdhirsh, https://github.com/albanD
Commit d73d4aa7de (parent 1285542f9b), committed by PyTorch MergeBot.
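To make the failure mode described in the commit message concrete, here is a minimal, self-contained sketch; the SymInt class below is a hypothetical stand-in for torch.SymIntNode (which is not a subclass of int), used purely for illustration:

    class SymInt:
        """Hypothetical stand-in for torch.SymIntNode; not PyTorch code."""
        def __init__(self, name):
            self.name = name

    # isinstance accepts a tuple of types, which is what the aliases in
    # torch._prims_common exploit.
    IntLike = (int, SymInt)

    x = SymInt("s0")
    print(isinstance(x, int))      # False -- the bug: the symbolic int is rejected
    print(isinstance(x, IntLike))  # True  -- the alias admits both kinds of int

Because a symbolic integer is not an int subclass, a direct isinstance(x, int) silently takes the wrong branch; the tuple aliases make both the concrete and symbolic cases pass.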
@@ -420,6 +420,35 @@ command = [
     '@{{PATHSFILE}}'
 ]
 
+[[linter]]
+code = 'ERROR_PRONE_ISINSTANCE'
+include_patterns = [
+    'torch/_refs/**/*.py',
+    'torch/_prims/**/*.py',
+    'torch/_prims_common/**/*.py',
+    'torch/_decomp/**/*.py',
+    'torch/_meta_registrations.py',
+]
+command = [
+    'python3',
+    'tools/linter/adapters/grep_linter.py',
+    '--pattern=isinstance\([^)]+(int|float)\)',
+    '--linter-name=ERROR_PRONE_ISINSTANCE',
+    '--error-name=error prone isinstance',
+    """--error-description=\
+        This line has an isinstance call that directly refers to \
+        int or float. This is error-prone because you may also \
+        have wanted to allow SymIntNode or SymFloatNode in your test. \
+        To suppress this lint, use an appropriate type alias defined \
+        in torch._prims_common; use IntLike/FloatLike when you would accept \
+        both regular and symbolic numbers, Dim for ints representing \
+        dimensions, or IntWithoutSymInt/FloatWithoutSymFloat if you really \
+        meant to exclude symbolic numbers.
+    """,
+    '--',
+    '@{{PATHSFILE}}'
+]
+
 [[linter]]
 code = 'PYBIND11_SPECIALIZATION'
 include_patterns = [
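The grep pattern registered above can be sanity-checked from Python; this is only an illustrative sketch using the re module, assuming re.search matches the same lines grep -E would:

    import re

    PATTERN = r"isinstance\([^)]+(int|float)\)"

    # Lines the linter should flag (direct int/float tests):
    assert re.search(PATTERN, "if isinstance(length, int):")
    assert re.search(PATTERN, "is_integral = isinstance(a, (bool, int)) or (")

    # Lines using the aliases, which the pattern leaves alone:
    assert re.search(PATTERN, "if isinstance(length, IntLike):") is None
    assert re.search(PATTERN, "if isinstance(dim, Dim):") is None

    print("pattern behaves as expected")

Note the pattern only fires when a literal int) or float) closes the call, so alias names like IntLike or Dim never match.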
@@ -181,6 +181,8 @@ class SymFloatNode(object):
     @staticmethod
     def new_symfloat(obj) -> SymFloatNode: ...
 
+    def __ceil__(self) -> SymIntNode: ...
+
 # Defined in torch/csrc/jit/passes/xnnpack_rewrite.h
 class MobileOptimizerType:
     ...
@@ -11,7 +11,7 @@ import torch._prims_common as utils
 import torch.nn.functional as F
 from torch import Tensor
 from torch._decomp import register_decomposition
-from torch._prims_common import NumberType, TensorLike, TensorSequenceType
+from torch._prims_common import IntLike, NumberType, TensorLike, TensorSequenceType
 from torch._prims_common.wrappers import _maybe_resize_out, _safe_copy_out, out_wrapper
 from torch.utils._pytree import tree_flatten, tree_map
 
@@ -1740,7 +1740,7 @@ def adaptive_avg_pool2d(input: Tensor, output_size: Tuple[int, int]):
         return torch.mean(vals, dim=(-3, -1))
 
     def maybe_mask(vals, length, range_max, adaptive, dim):
-        if isinstance(length, int):
+        if isinstance(length, IntLike):
             return vals, length
         else:
             # zero-out the things we didn't really want to select
@@ -11,6 +11,8 @@ from torch._prims_common import (
     corresponding_real_dtype,
     elementwise_dtypes,
     ELEMENTWISE_TYPE_PROMOTION_KIND,
+    FloatLike,
+    IntLike,
 )
 
 from torch._prims_common.wrappers import out_wrapper
@@ -361,24 +363,24 @@ def meta_conv(
     output_padding: Optional[Union[List[int], int]] = None,
 ):
     ret_shape = []
-    if isinstance(stride, int):
+    if isinstance(stride, IntLike):
         stride = [stride] * len(dims)
     elif len(stride) == 1:
         stride = [stride[0]] * len(dims)
 
-    if isinstance(padding, int):
+    if isinstance(padding, IntLike):
         padding = [padding] * len(dims)
     elif len(padding) == 1:
         padding = [padding[0]] * len(dims)
 
-    if isinstance(dilation, int):
+    if isinstance(dilation, IntLike):
         dilation = [dilation] * len(dims)
     elif len(dilation) == 1:
         dilation = [dilation[0]] * len(dims)
 
     output_padding_list: Optional[List[int]] = None
     if output_padding:
-        if isinstance(output_padding, int):
+        if isinstance(output_padding, IntLike):
             output_padding_list = [output_padding] * len(dims)
         elif len(output_padding) == 1:
             output_padding_list = [output_padding[0]] * len(dims)
@@ -1393,11 +1395,11 @@ def meta_like(self, *args, **kwargs):
 # hacky: Please remove after math.ceil works with arange
 @register_meta(aten.arange.default)
 def arange(end, **kwargs):
-    if isinstance(end, float):
-        end = math.ceil(end)
+    if isinstance(end, FloatLike):
+        end = math.ceil(end)  # type: ignore[arg-type]
 
     def is_integral(x):
-        return isinstance(x, int) or isinstance(x, bool)
+        return isinstance(x, IntLike) or isinstance(x, bool)
 
     set_to_integral_dtype = kwargs.get("dtype", None) is None and is_integral(end)
     if set_to_integral_dtype:
@@ -16,8 +16,10 @@ from torch._C import _get_default_device
 from torch._prims.nvfuser_prims import register_nvprims
 from torch._prims_common import (
     check,
+    Dim,
     DimsSequenceType,
     DimsType,
+    IntLike,
     Number,
     NumberType,
     RETURN_TYPE,
@@ -929,7 +931,7 @@ bitwise_xor = _make_elementwise_binary_prim(
 # div prim performs truncation division on integer inputs
 # and true division for floating and complex inputs
 def _div_aten(a, b):
-    is_integral = isinstance(a, (bool, int)) or (
+    is_integral = isinstance(a, (bool, int, torch.SymIntNode)) or (
         isinstance(a, torch.Tensor) and utils.is_integer_dtype(a.dtype)
     )
 
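The comment in this hunk distinguishes truncation division from Python's floor division; the two disagree on negative operands, which is the behavior the prim has to preserve. A pure-Python illustration (not the prim's actual implementation):

    import math

    a, b = -7, 2
    print(a // b)             # -4: floor division rounds toward negative infinity
    print(math.trunc(a / b))  # -3: truncation division rounds toward zero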
@@ -1198,7 +1200,7 @@ def _broadcast_in_dim_meta(
     # (no relative reordering of dims) of integers and
     # each dimension must be within the new shape
     def _greater_than_reduce(acc, x):
-        assert isinstance(x, int)
+        assert isinstance(x, Dim)
         assert x > acc
         assert x < len(shape)
 
@@ -2319,7 +2321,7 @@ def _arange_meta(
     )
     if dtype is not None:
         pass
-    elif all(isinstance(arg, int) for arg in (start, end, step)):
+    elif all(isinstance(arg, IntLike) for arg in (start, end, step)):
         dtype = torch.int64
     else:
         dtype = torch.get_default_dtype()
@@ -47,7 +47,15 @@ NumberTypeType = Union[Type[bool], Type[int], Type[float], Type[complex]]
 # TODO: This needs a lot more type annotations
 # NumberType = Union[bool, int, float, complex, torch.SymIntNode, torch.SymFloatNode]
 NumberType = Union[bool, int, float, complex]
 
 Number = (bool, int, float, complex, torch.SymIntNode, torch.SymFloatNode)
+# I don't call it Integral because numbers.Integral includes bool, but IntLike
+# does not
+Dim = int
+IntLike = (int, torch.SymIntNode)
+FloatLike = (float, torch.SymFloatNode)
+IntWithoutSymInt = int
+FloatWithoutSymFloat = float
+
 DeviceLikeType = Union[str, torch.device]
 Tensor = torch.Tensor
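A note on the aliases just introduced: they are plain tuples of types rather than typing unions, because isinstance accepts a tuple as its second argument but rejects a subscripted Union. A short sketch of the distinction, with SymInt again as a hypothetical stand-in for torch.SymIntNode:

    from typing import Union

    class SymInt:  # hypothetical stand-in for torch.SymIntNode
        pass

    IntLikeTuple = (int, SymInt)       # runtime checks: isinstance(x, IntLikeTuple)
    IntLikeUnion = Union[int, SymInt]  # static annotations: def f(x: IntLikeUnion)

    assert isinstance(3, IntLikeTuple)
    assert isinstance(SymInt(), IntLikeTuple)
    # isinstance(3, IntLikeUnion) would raise TypeError: subscripted generics
    # cannot be used with class and instance checks.

This mirrors the split in the hunk above between NumberType (a Union, for annotations) and Number (a tuple, for runtime tests).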
@@ -433,8 +441,8 @@ def validate_idx(rank: int, idx: int):
     Assumes the index is already canonicalized.
     """
 
-    assert isinstance(idx, int)
-    assert isinstance(rank, int)
+    assert isinstance(idx, Dim)
+    assert isinstance(rank, Dim)
 
     assert idx >= 0 and idx < rank or idx == 0
@@ -450,8 +458,8 @@ def validate_exclusive_idx(rank: int, ex_idx: int):
     for the given shape.
     """
 
-    assert isinstance(ex_idx, int)
-    assert isinstance(rank, int)
+    assert isinstance(ex_idx, Dim)
+    assert isinstance(rank, Dim)
     assert ex_idx > 0 and ex_idx <= rank
 
@@ -500,7 +508,7 @@ def canonicalize_dims(rank: int, indices: int) -> int:
 
 
 def canonicalize_dims(rank, indices):
-    if isinstance(indices, int):
+    if isinstance(indices, Dim):
         return canonicalize_dim(rank, indices)
 
     return tuple(canonicalize_dim(rank, x) for x in indices)
@@ -1439,7 +1447,8 @@ def set_correction(
         correction = 1
     elif correction is None and unbiased is not None:
         correction = 0 if unbiased is False else 1
-    if not isinstance(correction, int):
+    # NB: we don't actually support symint here, but it's harmless to accept
+    if not isinstance(correction, IntLike):
         raise ValueError("correction argument should be integer")
     if correction < 0:
         raise ValueError("correction argument should be non-negative")
@@ -16,10 +16,13 @@ import torch._prims_common as utils
 from torch._prims_common import (
     check,
     DeviceLikeType,
+    Dim,
     DimsSequenceType,
     DimsType,
     dtype_to_type,
     ELEMENTWISE_TYPE_PROMOTION_KIND,
+    FloatLike,
+    IntLike,
     is_weakly_lesser_type,
     Number,
     NumberType,
@@ -39,6 +42,7 @@ from torch._prims_common.wrappers import (
     elementwise_unary_scalar_wrapper,
     out_wrapper,
 )
+from torch.fx.experimental.symbolic_shapes import sym_float, sym_int
 
 # Experimental module containing prototype Python references for existing
 # PyTorch operations.
@@ -298,7 +302,7 @@ DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined]
 
 def _broadcast_shapes(*_shapes):
     shapes = tuple(
-        (x,) if isinstance(x, int) else x
+        (x,) if isinstance(x, IntLike) else x
         for x in filter(lambda x: x is not None, _shapes)
     )
 
@@ -1939,8 +1943,8 @@ def _reduction(
             "dtype argument and out dtype must match in reduction"
         )
     if not accepts_dim_tuple:
-        assert dims is None or isinstance(dims, int)
-    if isinstance(dims, int):
+        assert dims is None or isinstance(dims, Dim)
+    if isinstance(dims, Dim):
         dims = (dims,)  # type: ignore[assignment]
     dims = utils.reduction_dims(a.shape, dims)
     if not has_identity:
@@ -1986,7 +1990,7 @@ def all(
     keepdim: bool = False,
 ) -> TensorLikeType:
     # Computes nelem
-    if isinstance(dim, int):
+    if isinstance(dim, Dim):
         dim = (dim,)  # type: ignore[assignment]
 
     a_ = _maybe_convert_to_dtype(a, torch.bool)
@@ -2246,7 +2250,7 @@ def mean(
         )
     if utils.is_integer_dtype(dtype):
         raise RuntimeError("result type should be floating point or complex")
-    if isinstance(dim, int):
+    if isinstance(dim, Dim):
         dim = (dim,)  # type: ignore[assignment]
     dims = utils.reduction_dims(a.shape, dim)  # type: ignore[arg-type]
     nelem = 1 if a.ndim == 0 else reduce(operator.mul, (a.shape[i] for i in dims), 1)
@@ -3299,7 +3303,7 @@ def tensor_split(
         raise ValueError(msg)
 
     # Case 0 -- indices_or_sections is an integer or a scalar tensor n and a is split along dim into n parts of equal-ish length
-    if isinstance(indices_or_sections, int) or (
+    if isinstance(indices_or_sections, IntLike) or (
         isinstance(indices_or_sections, TensorLike) and indices_or_sections.ndim == 0
     ):
         sections: int = (
@@ -3365,7 +3369,7 @@ def hsplit(
         ),
     )
     dim = 0 if a.ndim == 1 else 1
-    if isinstance(indices_or_sections, int):
+    if isinstance(indices_or_sections, IntLike):
         split_size = indices_or_sections
         check(
             (split_size != 0 and a.shape[dim] % split_size == 0),
@@ -3407,7 +3411,7 @@ def vsplit(
             + " dimensions!"
         ),
     )
-    if isinstance(indices_or_sections, int):
+    if isinstance(indices_or_sections, IntLike):
         split_size = indices_or_sections
         check(
             (split_size != 0 and a.shape[0] % split_size == 0),
@@ -3538,7 +3542,7 @@ def dsplit(a: TensorLikeType, sections: DimsType) -> TensorSequenceType:
         raise RuntimeError(
             f"torch.dsplit requires a tensor with at least 3 dimension, but got a tensor with {a.ndim} dimensions!"
         )
-    if isinstance(sections, int) and (sections == 0 or a.shape[2] % sections != 0):
+    if isinstance(sections, IntLike) and (sections == 0 or a.shape[2] % sections != 0):
         raise RuntimeError(
             "torch._refs.dsplit attempted to split along dimension 2, "
             + f"but the size of the dimension {a.shape[2]} is not divisible by the split_size {sections}!"
@@ -3983,21 +3987,21 @@ def linspace(
     # cast than not, because it allows us to always go into the precise path
     # if dtype is integral and not worry about whether start/end are float
     if prims.utils.is_integer_dtype(dtype):
-        if isinstance(start, float):
-            start = int(start)
-        if isinstance(end, float):
-            end = int(end)
+        if isinstance(start, FloatLike):
+            start = sym_int(start)
+        if isinstance(end, FloatLike):
+            end = sym_int(end)
 
     if py_any(isinstance(arg, complex) for arg in (start, end, steps)):
         raise NotImplementedError
     assert not isinstance(start, complex) and not isinstance(end, complex)  # for mypy
 
     check(
-        isinstance(steps, int),
+        isinstance(steps, IntLike),
         lambda: "steps must be int, not float",
         exc_type=TypeError,
     )
-    assert isinstance(steps, int)  # for mypy
+    assert isinstance(steps, IntLike)  # for mypy
     check(steps >= 0, lambda: "number of steps must be non-negative")
 
     factory_kwargs = {
@@ -4016,7 +4020,7 @@ def linspace(
     if prims.utils.is_integer_dtype(dtype):
         # We need to cast to int, so to avoid off-by-one issues
         # do the entire computation with ints when we can
-        assert isinstance(start, int) and isinstance(end, int)
+        assert isinstance(start, IntLike) and isinstance(end, IntLike)
         step_size_x_denom = end - start
         eps = 1 if end > start else -1
         denom = steps - 1
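The off-by-one concern in the comment above is the classic floating-point stepping problem; a small illustration of why the ref keeps the arithmetic in integers (independent of the ref's actual computation):

    # Accumulating float steps drifts away from the exact value:
    print(0.1 + 0.1 + 0.1 == 0.3)  # False: accumulated rounding error
    # Doing the computation in integers and dividing once avoids this:
    print((1 + 1 + 1) / 10 == 0.3)  # True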
@@ -4063,10 +4067,10 @@ def logspace(
 
     # NB: NumPy doesn't have this cast
     if prims.utils.is_integer_dtype(dtype):
-        if isinstance(start, float):
-            start = int(start)
-        if isinstance(end, float):
-            end = int(end)
+        if isinstance(start, FloatLike):
+            start = sym_int(start)
+        if isinstance(end, FloatLike):
+            end = sym_int(end)
 
     assert not isinstance(base, complex)  # for mypy
     if base < 0:
@@ -4402,10 +4406,10 @@ def uniform(
 ) -> TensorLikeType:
     utils.validate_shape(shape)
 
-    assert isinstance(low, (bool, int, float))
-    assert isinstance(high, (bool, int, float))
-    low = float(low)
-    high = float(high)
+    assert isinstance(low, Number)
+    assert isinstance(high, Number)
+    low = sym_float(low)
+    high = sym_float(high)
 
     assert isinstance(dtype, torch.dtype)
     device = utils.canonicalize_device(device)
@@ -4505,10 +4509,10 @@ def norm(
 ) -> TensorLikeType:
     # In these cases we compute the "Frobenius norm"
     if (
-        p == "fro" and (dim is None or isinstance(dim, int) or len(dim) <= 2)
+        p == "fro" and (dim is None or isinstance(dim, Dim) or len(dim) <= 2)
     ) or p is None:
         p = 2
-    if isinstance(dim, int):
+    if isinstance(dim, Dim):
         dim = [dim]
     if isinstance(p, str):
         # Here we either call the nuclear norm, or we call matrix_norm with some arguments
@@ -14,6 +14,7 @@ from torch._prims_common import (
     check,
     check_fp_or_complex,
     check_is_matrix,
+    Dim,
     DimsType,
     NumberType,
     TensorLikeType,
@@ -69,7 +70,7 @@ def vector_norm(
     # Checks
     check_fp_or_complex(x.dtype, "linalg.vector_norm")
 
-    if isinstance(dim, int):
+    if isinstance(dim, Dim):
         dim = [dim]  # type: ignore[assignment]
     elif not isinstance(dim, List) and dim is not None:
         # refs.amin just accepts List rather than DimType (Tuple)
@@ -142,7 +143,7 @@ def matrix_norm(
     check_is_matrix(A, "linalg.matrix_norm")
     # dim
     dim = utils.canonicalize_dims(A.ndim, dim)
-    if isinstance(dim, int):
+    if isinstance(dim, Dim):
         dim = (dim,)  # type: ignore[assignment]
     check(len(dim) == 2, lambda: "linalg.matrix_norm: dim must be a 2-tuple. Got {dim}")
     check(
@@ -219,7 +220,7 @@ def norm(
     dtype: Optional[torch.dtype] = None,
 ) -> TensorLikeType:
     if dim is not None:
-        if isinstance(dim, int):
+        if isinstance(dim, Dim):
             dim = (dim,)  # type: ignore[assignment]
         check(
             len(dim) in (1, 2),