[BE] Enable more flake8-comprehensions checks (#94601)

I applied some flake8-comprehensions fixes and enabled checking for them in the linter. I also enabled checks covering the fixes from my previous comprehensions PR.

This is a follow-up to #94323: it enables the flake8 checks for the fixes made there and fixes a few more instances.
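For reference, these are the kinds of mechanical rewrites the newly enabled checks enforce (an illustrative sketch with made-up names, not code from this diff):

    # C403: build the set directly instead of calling set() on a list comprehension
    unique_names = {n.strip() for n in names}   # was: set([n.strip() for n in names])

    # C404: build the dict directly instead of calling dict() on a list of pairs
    name_lengths = {n: len(n) for n in names}   # was: dict([(n, len(n)) for n in names])

    # C413/C414: sorted() already returns a new list, so the extra list() is redundant
    ordered = sorted(unique_names)              # was: list(sorted(unique_names)) or sorted(list(unique_names))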

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94601
Approved by: https://github.com/ezyang
Author: Aaron Gokaslan
Date: 2023-02-10 23:40:26 +00:00
Committed by: PyTorch MergeBot
Parent: 0b31ebf9e4
Commit: 3d82d8d0ed
30 changed files with 71 additions and 82 deletions


@@ -11,7 +11,7 @@ ignore =
     # these ignores are from flake8-bugbear; please fix!
     B007,B008,
     # these ignores are from flake8-comprehensions; please fix!
-    C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
+    C400,C401,C402,C405,C407
 per-file-ignores =
     __init__.py: F401
     torch/utils/cpp_extension.py: B950


@@ -181,7 +181,7 @@ class OperatorInputsMode(TorchDispatchMode):
         return out

     def log_to_file(self, output_filename, *, skip_non_compute_operators=True):
-        sorted_operators = sorted(list(self.func_db.keys()))
+        sorted_operators = sorted(self.func_db.keys())
         with open(output_filename, "w") as f:
             for operator in sorted_operators:
                 if skip_non_compute_operators and non_compute_operator(eval(operator)):


@@ -163,7 +163,7 @@ def tensortype_to_ndarray(tensor_type):

 def generate_test_input_data(onnx_model, scale):
-    real_inputs_names = list(set([input.name for input in onnx_model.graph.input]) - set([init.name for init in onnx_model.graph.initializer]))
+    real_inputs_names = list({input.name for input in onnx_model.graph.input} - {init.name for init in onnx_model.graph.initializer})
     real_inputs = []
     for name in real_inputs_names:
         for input in onnx_model.graph.input:


@@ -2297,7 +2297,7 @@ class DistributedDataParallelTest(
             store=store,
         )

         seqs = ["sequence_sequence", "seq", "sequence"]
-        vocab = ["<pad>"] + sorted(set([ch for seq in seqs for ch in seq]))
+        vocab = ["<pad>"] + sorted({ch for seq in seqs for ch in seq})
         vectorized_seqs = [[vocab.index(tok) for tok in seq] for seq in seqs]
         # Set the seed to make the embedding and LSTM deterministic (even
         # across ranks since DDP broadcasts parameters from rank 0)


@@ -426,7 +426,7 @@ def remove_torch(name):

 def get_list_of_all_tests():
     all_tests = list(tested_overridable_outplace_ops.keys())
-    return set([remove_torch(test) for test in all_tests])
+    return {remove_torch(test) for test in all_tests}

 mytest = {
@@ -459,11 +459,11 @@ def get_jvp_coverage(subset=None):
     supports_forwardad_ops_dct = {name: op_to_opinfo[fn] for name, fn in ops_dct.items()
                                   if op_to_opinfo[fn][0].supports_forward_ad}
-    ops = set([remove_torch(test) for test in list(ops_dct.keys())])
-    supports_autograd = set([remove_torch(test)
-                             for test in list(supports_autograd_ops_dct.keys())])
-    supports_forward_ad = set([remove_torch(test)
-                               for test in list(supports_forwardad_ops_dct.keys())])
+    ops = {remove_torch(test) for test in list(ops_dct.keys())}
+    supports_autograd = {remove_torch(test)
+                         for test in list(supports_autograd_ops_dct.keys())}
+    supports_forward_ad = {remove_torch(test)
+                           for test in list(supports_forwardad_ops_dct.keys())}
     assert supports_forward_ad.issubset(supports_autograd)
     assert supports_autograd.issubset(ops)


@@ -169,12 +169,12 @@ class TestPythonKey(AOTTestCase):
             return torch.tanh(x).sum()

         fx_f = make_fx(grad(f))(torch.randn(5))
-        ops = set([i.target for i in fx_f.graph.nodes])
+        ops = {i.target for i in fx_f.graph.nodes}
         self.assertEqual(torch.ops.aten.tanh_backward in ops, True)

         fx_f = make_fx(grad(f), decomposition_table)(torch.randn(5))
-        ops = set([i.target for i in fx_f.graph.nodes])
+        ops = {i.target for i in fx_f.graph.nodes}
         self.assertEqual(torch.ops.aten.tanh_backward in ops, False)

     def test_nnc_jit(self, device):


@@ -18,7 +18,7 @@ class TestMinifier(TestCase):
         failing_f = make_fx(failing_f)(*inps)

         def has_mul(fx_g, inps):
-            return (torch.ops.aten.mul.Tensor in set([i.target for i in fx_g.graph.nodes]))
+            return (torch.ops.aten.mul.Tensor in (i.target for i in fx_g.graph.nodes))

         min_f, inps = minifier(failing_f, inps, has_mul)
         self.assertEqual(len(min_f.graph.nodes), 4)
@@ -74,7 +74,7 @@ class TestMinifier(TestCase):
         inps = [torch.randn(3), torch.randn(3)]

         def has_add(fx_g, inps):
-            return (torch.ops.aten.add.Tensor in set([i.target for i in fx_g.graph.nodes]))
+            return (torch.ops.aten.add.Tensor in (i.target for i in fx_g.graph.nodes))

         failing_f = make_fx(f)(*inps)
         min_f, inps = minifier(failing_f, inps, has_add)


@@ -114,7 +114,7 @@ def get_suggested_xfails(base, tests):
     tests = [test[len(base):] for test in tests if
              belongs_to_base(test, base)]

-    base_tests = set([remove_device_dtype(test) for test in tests])
+    base_tests = {remove_device_dtype(test) for test in tests}
     tests = set(tests)
     for base in base_tests:
         cpu_variant = base + '_cpu_float32'


@@ -226,7 +226,7 @@ class TestList(JitTestCase):
         self.checkScript(foo2, ())

         def foo3():
-            return list(list("abc"))
+            return list(list("abc"))  # noqa: C414

         self.checkScript(foo3, ())
         FileCheck().check_count("aten::list", 2, exactly=True).run(torch.jit.script(foo3).graph)


@@ -140,7 +140,7 @@ def calcOpsCoverage(ops):
                 "_coverage": round(coverage, 2),
                 "uncovered_ops": uncovered_ops_dict,
                 "covered_ops": covered_ops_dict,
-                "all_generated_ops": sorted(list(all_generated_ops)),
+                "all_generated_ops": sorted(all_generated_ops),
             },
             f,
         )


@@ -40,7 +40,7 @@ def run_model_test(test_suite: _TestONNXRuntime, *args, **kwargs):
     if hasattr(test_suite, "check_dtype"):
         options.check_dtype = test_suite.check_dtype

-    names = set([f.name for f in dataclasses.fields(options)])
+    names = {f.name for f in dataclasses.fields(options)}
     keywords_to_pop = []
     for k, v in kwargs.items():
         if k in names:


@@ -116,7 +116,7 @@ class TestDiGraph(PackageTestCase):
         result = g.all_paths("1", "3")

         # to get rid of indeterminism
-        actual = set([i.strip("\n") for i in result.split(";")[2:-1]])
+        actual = {i.strip("\n") for i in result.split(";")[2:-1]}
         expected = {
             '"2" -> "3"',
             '"1" -> "7"',


@@ -365,10 +365,10 @@ class TestQuantizeEagerPTQStatic(QuantizationTestCase):
         # test one line API - out of place version
         base = AnnotatedSingleLayerLinearModel(qengine)
         base.qconfig = qconfig
-        keys_before = set(list(base.state_dict().keys()))
+        keys_before = set(base.state_dict().keys())
         model = quantize(base, test_only_eval_fn, [self.calib_data])
         checkQuantized(model)
-        keys_after = set(list(base.state_dict().keys()))
+        keys_after = set(base.state_dict().keys())
         self.assertEqual(keys_before, keys_after) # simple check that nothing changed

         # in-place version
@@ -1107,10 +1107,10 @@ class TestQuantizeEagerPTQDynamic(QuantizationTestCase):
         # test one line API - out of place version
         base = SingleLayerLinearDynamicModel()
-        keys_before = set(list(base.state_dict().keys()))
+        keys_before = set(base.state_dict().keys())
         model = quantize_dynamic(base, qconfig_dict)
         checkQuantized(model)
-        keys_after = set(list(base.state_dict().keys()))
+        keys_after = set(base.state_dict().keys())
         self.assertEqual(keys_before, keys_after) # simple check that nothing changed

         # in-place version


@@ -900,7 +900,7 @@ class TestFxModelReportClass(QuantizationTestCase):
             model_report = ModelReport(model_prep, test_detector_set)

             # make sure internal valid reports matches
-            detector_name_set = set([detector.get_detector_name() for detector in test_detector_set])
+            detector_name_set = {detector.get_detector_name() for detector in test_detector_set}
             self.assertEqual(model_report.get_desired_reports_names(), detector_name_set)

             # now attempt with no valid reports, should raise error
@@ -1329,7 +1329,7 @@ class TestFxDetectInputWeightEqualization(QuantizationTestCase):
         mods_to_check = set([nn.Linear, nn.Conv2d])

         # get the set of all nodes in the graph their fqns
-        node_fqns = set([node.target for node in prepared_for_callibrate_model.graph.nodes])
+        node_fqns = {node.target for node in prepared_for_callibrate_model.graph.nodes}

         # there should be 4 node fqns that have the observer inserted
         correct_number_of_obs_inserted = 4


@@ -167,7 +167,7 @@ class TestNamedTupleAPI(TestCase):
                 ret3 = meth(*op.input)
                 check_namedtuple(ret3, op.names)

-        all_covered_operators = set([x for y in operators for x in y.operators])
+        all_covered_operators = {x for y in operators for x in y.operators}

         self.assertEqual(all_operators_with_namedtuple_return, all_covered_operators, textwrap.dedent('''
         The set of covered operators does not match the `all_operators_with_namedtuple_return` of


@@ -579,7 +579,7 @@ def forward(self, x_1):

         gm = make_fx(Emformer())(torch.randn(16, 1, 256))
-        ops = set([n.target for n in gm.graph.nodes if n.op == 'call_function'])
+        ops = {n.target for n in gm.graph.nodes if n.op == 'call_function'}
         self.assertEqual(len(ops), 2)


@@ -264,7 +264,7 @@ class TestSparse(TestSparseBase):
             else:
                 value_map[idx_tup] = val.clone() if isinstance(val, torch.Tensor) else val

-        new_indices = sorted(list(value_map.keys()))
+        new_indices = sorted(value_map.keys())
         _new_values = [value_map[idx] for idx in new_indices]
         if t._values().ndimension() < 2:
             new_values = t._values().new(_new_values)


@@ -130,13 +130,11 @@ FILENAME_ALLOWLIST = {
 }

 # Include optimizer code for tracing
-FILENAME_ALLOWLIST |= set(
-    [
-        inspect.getfile(obj)
-        for obj in torch.optim.__dict__.values()
-        if inspect.isclass(obj)
-    ]
-)
+FILENAME_ALLOWLIST |= {
+    inspect.getfile(obj)
+    for obj in torch.optim.__dict__.values()
+    if inspect.isclass(obj)
+}
 FILENAME_ALLOWLIST |= {torch.optim._functional.__file__}

 if HAS_PRIMS_REFS:


@@ -760,7 +760,7 @@ def enum_repr(value):

 def dict_param_key_ids(value):
-    return set([id(k) for k in value.keys() if isinstance(k, torch.nn.Parameter)])
+    return {id(k) for k in value.keys() if isinstance(k, torch.nn.Parameter)}

 def dict_const_keys(value):
@@ -771,7 +771,7 @@ def dict_const_keys_repr(const_keys):
     if any(isinstance(k, enum.Enum) for k in const_keys):
         # To workaround repr(Enum) returning invalid global reference before python 3.11
         # by calling enum_repr and removing quotes to render enum in guard code.
-        const_keys_str = f"{set([enum_repr(k) if isinstance(k, enum.Enum) else repr(k) for k in const_keys])}".replace(
+        const_keys_str = f"{set(enum_repr(k) if isinstance(k, enum.Enum) else repr(k) for k in const_keys)}".replace(
             "'", ""
         )
     else:


@@ -304,17 +304,12 @@ class VariableBuilder:
             else:
                 return key

-        result = dict(
-            [
-                (
-                    k,
-                    VariableBuilder(
-                        self.tx, GetItemSource(self.get_source(), index_source(k))
-                    )(value[k]).add_guards(guards),
-                )
-                for k in value.keys()
-            ]
-        )
+        result = {
+            k: VariableBuilder(
+                self.tx, GetItemSource(self.get_source(), index_source(k))
+            )(value[k]).add_guards(guards)
+            for k in value.keys()
+        }

         if istype(value, collections.defaultdict):
             result = DefaultDictVariable(


@@ -393,7 +393,7 @@ def min_cut_rematerialization_partition(
         for node in joint_module.graph.nodes
         if node.op == "call_function" and hasattr(node.target, "_overloadpacket")
     )
-    ops_ignored = joint_module_ops - set([str(i) for i in recomputable_ops])
+    ops_ignored = joint_module_ops - {str(i) for i in recomputable_ops}
     print("Ops banned from rematerialization: ", ops_ignored)
     print()
@@ -522,8 +522,8 @@ def min_cut_rematerialization_partition(
         joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, num_fwd_outputs=num_fwd_outputs)
     if AOT_PARTITIONER_DEBUG:
         print("Theoretical Activations Stored: ", sum([_size_of(i) for i in saved_values]) / 1e9)
-        fw_module_nodes = set([node.name for node in fw_module.graph.nodes if node.op == 'call_function'])
-        bw_module_nodes = set([node.name for node in bw_module.graph.nodes if node.op == 'call_function'])
+        fw_module_nodes = {node.name for node in fw_module.graph.nodes if node.op == 'call_function'}
+        bw_module_nodes = {node.name for node in bw_module.graph.nodes if node.op == 'call_function'}
         remat_nodes = fw_module_nodes & bw_module_nodes
         counts = defaultdict(int)


@@ -535,9 +535,7 @@ class GraphLowering(torch.fx.Interpreter):
         writes = set(dep.name for dep in node.read_writes.writes)

         def is_materialized(buf):
-            buf_uses = set(
-                [user.node for user in scheduler.name_to_node[buf].users]
-            )
+            buf_uses = {user.node for user in scheduler.name_to_node[buf].users}
             return len(buf_uses - set(node.snodes)) > 0

         if isinstance(node, FusedSchedulerNode):


@@ -344,7 +344,9 @@ def fresh_inductor_cache(cache_entries=None):

 def argsort(seq):
     # preserve original order for equal strides
-    return list(reversed(sorted(range(len(seq)), key=seq.__getitem__, reverse=True)))
+    getter = seq.__getitem__
+    a_r = range(len(seq))
+    return list(reversed(sorted(a_r, key=getter, reverse=True))) # noqa: C413

 @functools.lru_cache(8)


@@ -120,7 +120,7 @@ class ModelReport:

         # keep the reports private so they can't be modified
         self._desired_report_detectors = desired_report_detectors
-        self._desired_detector_names = set([detector.get_detector_name() for detector in desired_report_detectors])
+        self._desired_detector_names = {detector.get_detector_name() for detector in desired_report_detectors}

         # keep a mapping of desired reports to observers of interest
         # this is to get the readings, and to remove them, can create a large set


@@ -1598,7 +1598,7 @@ def _all_gather_optim_state(
     gathered_state: Dict[str, Any] = {}
     all_tensor_states = sorted(
-        set([n for state in object_list for n in state.tensors.keys()])
+        {n for state in object_list for n in state.tensors.keys()}
     )
     empty_ranks: Set[int] = set()
     for name in all_tensor_states:


@@ -264,7 +264,7 @@ class Tracer(TracerBase):
             for name, value in chain(*[m.__dict__.items() for m in autowrap_modules])
             if not name.startswith("_") and callable(value)
         }
-        self._autowrap_function_ids.update(set([id(f) for f in autowrap_functions]))
+        self._autowrap_function_ids.update({id(f) for f in autowrap_functions})

         # Python modules to apply autowrap to at the start, in addition to
         # modules we see while tracing


@@ -3611,8 +3611,8 @@ def random_sparse_pd_matrix(matrix_size, density=0.01, **kwargs):
     torch = kwargs.get('torch', globals()['torch'])
     dtype = kwargs.get('dtype', torch.double)
     device = kwargs.get('device', 'cpu')
-    data = dict([((i, i), float(i + 1) / matrix_size)
-                 for i in range(matrix_size)])
+    data = {(i, i): float(i + 1) / matrix_size
+            for i in range(matrix_size)}

     def multiply(data, N, i, j, cs, sn, left=True):


@@ -377,29 +377,25 @@ def gen_dispatchkey_nativefunc_headers(
     # Convert to a set first to remove duplicate kernel names.
     # Backends are allowed to repeat kernel names; only generate the declaration once!
     # Sort for deterministic output.
-    backend_declarations = list(
-        sorted(
-            set(
-                concatMap(
-                    lambda f: dest.compute_native_function_declaration(
-                        f, backend_indices[backend_dispatch_key]
-                    ),
-                    grouped_native_functions,
-                )
-            )
-        )
-    )
+    backend_declarations = sorted(
+        set(
+            concatMap(
+                lambda f: dest.compute_native_function_declaration(
+                    f, backend_indices[backend_dispatch_key]
+                ),
+                grouped_native_functions,
+            )
+        )
+    )
-    autograd_declarations = list(
-        sorted(
-            set(
-                concatMap(
-                    lambda f: []
-                    if autograd_dispatch_key is None
-                    else dest.compute_native_function_declaration(
-                        f, backend_indices[autograd_dispatch_key]
-                    ),
-                    grouped_native_functions,
-                )
-            )
-        )
-    )
+    autograd_declarations = sorted(
+        set(
+            concatMap(
+                lambda f: []
+                if autograd_dispatch_key is None
+                else dest.compute_native_function_declaration(
+                    f, backend_indices[autograd_dispatch_key]
+                ),
+                grouped_native_functions,
+            )
+        )
+    )


@@ -1058,7 +1058,7 @@ class NativeFunctionsGroup:
         for f in self.functions():
             expected_generated_fns.update(str(op) for op in f.autogen)
         expected_generated_fns_str = ", ".join(
-            str(x) for x in sorted(list(expected_generated_fns))
+            str(x) for x in sorted(expected_generated_fns)
         )
         if len(expected_generated_fns) == 0 and len(generated_fns) > 0:
             raise RuntimeError(


@@ -231,7 +231,7 @@ class SelectiveBuilder:
         ret["debug_info"] = sorted(self._debug_info)
         ret["kernel_metadata"] = {
-            k: sorted(list(v)) for (k, v) in self.kernel_metadata.items()
+            k: sorted(v) for (k, v) in self.kernel_metadata.items()
         }
         ret["custom_classes"] = sorted(self.custom_classes)