[BE] Enable flake8-simplify checks (#97984)

Enable some sensible flake8-simplify rules. Mainly wanted to enable the SIM101, and `yield from` SIM103 checks. @kit1980 since you wanted to be tagged on this CI check.

Enabling this check also helped flag one logical bug so it's definitely beneficial (also fixed in this PR).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97984
Approved by: https://github.com/ezyang
This commit is contained in:
Aaron Gokaslan
2023-03-31 03:40:21 +00:00
committed by PyTorch MergeBot
parent 3dc4405278
commit 9c3fbe7475
9 changed files with 29 additions and 30 deletions

View File

@ -1,6 +1,6 @@
[flake8]
enable-extensions = G
select = B,C,E,F,G,P,T4,W,B9
select = B,C,E,F,G,P,SIM1,T4,W,B9
max-line-length = 120
# C408 ignored because we like the dict keyword argument syntax
# E501 is not flexible enough, we're using B950 instead
@ -16,7 +16,11 @@ ignore =
# these ignores are from flake8-comprehensions; please fix!
C407
# these ignores are from flake8-logging-format; please fix!
G001,G002,G003,G004,G100,G101,G200,G201,G202
G001,G002,G003,G004,G100,G101,G200,G201,G202,
# these ignores are from flake8-simplify. please fix or ignore with commented reason
SIM105,SIM108,SIM109,SIM110,SIM111,SIM113,SIM114,SIM115,SIM116,SIM117,SIM118,SIM119,SIM12,
# flake8-simplify code styles
SIM102,SIM103,SIM106,SIM112,
per-file-ignores =
__init__.py: F401
torch/utils/cpp_extension.py: B950

View File

@ -39,6 +39,7 @@ init_command = [
'flake8-executable==2.1.3',
'flake8-logging-format==0.9.0',
'flake8-pyi==23.3.1',
'flake8-simplify==0.19.3',
'mccabe==0.7.0',
'pycodestyle==2.10.0',
'pyflakes==3.0.1',

View File

@ -457,8 +457,8 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
if self.kwdefaults:
flags |= 0x02
codegen(self.kwdefaults)
if isinstance(self.annotations, variables.ConstDictVariable) or isinstance(
self.annotations, variables.TupleVariable
if isinstance(
self.annotations, (variables.ConstDictVariable, variables.TupleVariable)
):
flags |= 0x04
try:

View File

@ -200,9 +200,7 @@ class ContextWrappingVariable(VariableTracker):
assert len(args) == 1
if isinstance(args[0], NestedUserFunctionVariable):
args[0] = UserFunctionVariable(args[0].get_function())
assert isinstance(args[0], UserMethodVariable) or isinstance(
args[0], UserFunctionVariable
)
assert isinstance(args[0], (UserMethodVariable, UserFunctionVariable))
if isinstance(args[0], UserMethodVariable):
return WrappedUserMethodVariable(args[0], self)

View File

@ -4568,7 +4568,7 @@ def meshgrid(
# This ref simultaneously handles two overloads (see stubs above)
# The `indexing` argument is currently optional for torch.meshgrid, but we
# plan to make the argument required: https://github.com/pytorch/pytorch/issues/50276
if isinstance(tensors[0], list) or isinstance(tensors[0], tuple):
if isinstance(tensors[0], (list, tuple)):
assert len(tensors) == 1
tensors = tuple(tensors[0])

View File

@ -154,7 +154,7 @@ def _dtensor_expand(
if isinstance(a, torch.Tensor):
inps.append(a)
schemas.append(shard_schema)
elif isinstance(a, nn.Module) or isinstance(a, torch.optim.Optimizer):
elif isinstance(a, (nn.Module, torch.optim.Optimizer)):
# nn.Module or optimizer placeholder is captured by make_fx but
# never used in the graph
inps.append(torch.empty(0))

View File

@ -527,8 +527,8 @@ def getitem_inference_rule(n: Node, symbols, constraints, counter):
@register_inference_rule(operator.gt)
def gt_inference_rule(n: Node, symbols, constraints, counter):
assert isinstance(n.args[0], Node) or isinstance(n.args[0], int)
assert isinstance(n.args[1], Node) or isinstance(n.args[1], int)
assert isinstance(n.args[0], (Node, int))
assert isinstance(n.args[1], (Node, int))
# We make sure this node will not be used again. We do not
# generate a constraint about that node. Only about the operands.
@ -586,8 +586,8 @@ def gt_inference_rule(n: Node, symbols, constraints, counter):
@register_inference_rule(operator.eq)
def eq_inference_rule(n: Node, symbols, constraints, counter):
assert isinstance(n.args[0], Node) or isinstance(n.args[0], int)
assert isinstance(n.args[1], Node) or isinstance(n.args[1], int)
assert isinstance(n.args[0], (Node, int))
assert isinstance(n.args[1], (Node, int))
e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0]
e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1]
@ -640,9 +640,9 @@ def neq_inference_rule(n: Node, symbols, constraints, counter):
# implementing for size 3 and 4
if len(n.args[1]) == 3:
assert isinstance(n.args[1][0], Node) or isinstance(n.args[1][0], int)
assert isinstance(n.args[1][1], Node) or isinstance(n.args[1][1], int)
assert isinstance(n.args[1][2], Node) or isinstance(n.args[1][2], int)
assert isinstance(n.args[1][0], (Node, int))
assert isinstance(n.args[1][1], (Node, int))
assert isinstance(n.args[1][2], (Node, int))
lhs = symbols[n.args[0]]
@ -674,10 +674,10 @@ def neq_inference_rule(n: Node, symbols, constraints, counter):
elif len(n.args[1]) == 4:
assert isinstance(n.args[1][0], Node) or isinstance(n.args[1][0], int)
assert isinstance(n.args[1][1], Node) or isinstance(n.args[1][1], int)
assert isinstance(n.args[1][2], Node) or isinstance(n.args[1][2], int)
assert isinstance(n.args[1][3], Node) or isinstance(n.args[1][3], int)
assert isinstance(n.args[1][0], (Node, int))
assert isinstance(n.args[1][1], (Node, int))
assert isinstance(n.args[1][2], (Node, int))
assert isinstance(n.args[1][3], (Node, int))
lhs = symbols[n.args[0]]
@ -722,8 +722,8 @@ def neq_inference_rule(n: Node, symbols, constraints, counter):
@register_inference_rule(operator.lt)
def lt_inference_rule(n: Node, symbols, constraints, counter):
assert isinstance(n.args[0], Node) or isinstance(n.args[0], int)
assert isinstance(n.args[1], Node) or isinstance(n.args[1], int)
assert isinstance(n.args[0], (Node, int))
assert isinstance(n.args[1], (Node, int))
# We make sure this node will not be used again. We do not
# generate a constraint about that node. Only about the operands.
@ -845,7 +845,7 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter):
else:
raise NotImplementedError('Method not yet implemented')
elif isinstance(n.args[0], Node) and (isinstance(n.args[1], int) or isinstance(n.args[1], float)):
elif isinstance(n.args[0], Node) and isinstance(n.args[1], (int, float)):
if isinstance(symbols[n.args[0]], TVar):
my_output, counter = gen_tvar(counter)
symbols[n] = my_output
@ -861,7 +861,7 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter):
BinConstraintD(0, my_output, op_leq)])
return [c], counter
elif isinstance(n.args[1], Node) and (isinstance(n.args[0], int) or isinstance(n.args[1], float)):
elif isinstance(n.args[1], Node) and isinstance(n.args[0], (int, float)):
if isinstance(symbols[n.args[1]], TVar):
my_output, counter = gen_tvar(counter)
symbols[n] = my_output

View File

@ -582,9 +582,7 @@ class MatMulDimInFP16Pattern(Pattern):
def source_code_location(event: Optional[_ProfilerEvent]):
while event:
if event.tag == _EventType.PyCall or event.tag == _EventType.PyCCall:
assert isinstance(event.extra_fields,
_ExtraFields_PyCall) or isinstance(
event.extra_fields, _ExtraFields_PyCCall)
assert isinstance(event.extra_fields, (_ExtraFields_PyCall, _ExtraFields_PyCCall))
if not event.extra_fields.caller.file_name.startswith("torch" +
os.sep):
return f"{event.extra_fields.caller.file_name}:{event.extra_fields.caller.line_number}"

View File

@ -204,9 +204,7 @@ class GenLazyIR(ABC):
# as long as all of its arguments can be generated from information available from the schema
base_ctor_value_args_list = []
for arg in value_args:
if isinstance(arg.lazy_type, BaseCType) or isinstance(
arg.lazy_type, VectorCType
):
if isinstance(arg.lazy_type, (BaseCType, VectorCType)):
base_ctor_value_args_list.append(f"{arg.name}")
elif isinstance(arg.lazy_type, OptionalCType):
base_ctor_value_args_list.append(f"{arg.name}.value_or(kNullValue)")