mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] Enable more flake8-comprehensions checks (#94601)
I applied some flake8 fixes and enabled checking for them in the linter. I also enabled some checks for my previous comprehensions PR. This is a follow up to #94323 where I enable the flake8 checkers for the fixes I made and fix a few more of them. Pull Request resolved: https://github.com/pytorch/pytorch/pull/94601 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
0b31ebf9e4
commit
3d82d8d0ed
2
.flake8
2
.flake8
@ -11,7 +11,7 @@ ignore =
|
||||
# these ignores are from flake8-bugbear; please fix!
|
||||
B007,B008,
|
||||
# these ignores are from flake8-comprehensions; please fix!
|
||||
C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
|
||||
C400,C401,C402,C405,C407
|
||||
per-file-ignores =
|
||||
__init__.py: F401
|
||||
torch/utils/cpp_extension.py: B950
|
||||
|
@ -181,7 +181,7 @@ class OperatorInputsMode(TorchDispatchMode):
|
||||
return out
|
||||
|
||||
def log_to_file(self, output_filename, *, skip_non_compute_operators=True):
|
||||
sorted_operators = sorted(list(self.func_db.keys()))
|
||||
sorted_operators = sorted(self.func_db.keys())
|
||||
with open(output_filename, "w") as f:
|
||||
for operator in sorted_operators:
|
||||
if skip_non_compute_operators and non_compute_operator(eval(operator)):
|
||||
|
@ -163,7 +163,7 @@ def tensortype_to_ndarray(tensor_type):
|
||||
|
||||
|
||||
def generate_test_input_data(onnx_model, scale):
|
||||
real_inputs_names = list(set([input.name for input in onnx_model.graph.input]) - set([init.name for init in onnx_model.graph.initializer]))
|
||||
real_inputs_names = list({input.name for input in onnx_model.graph.input} - {init.name for init in onnx_model.graph.initializer})
|
||||
real_inputs = []
|
||||
for name in real_inputs_names:
|
||||
for input in onnx_model.graph.input:
|
||||
|
@ -2297,7 +2297,7 @@ class DistributedDataParallelTest(
|
||||
store=store,
|
||||
)
|
||||
seqs = ["sequence_sequence", "seq", "sequence"]
|
||||
vocab = ["<pad>"] + sorted(set([ch for seq in seqs for ch in seq]))
|
||||
vocab = ["<pad>"] + sorted({ch for seq in seqs for ch in seq})
|
||||
vectorized_seqs = [[vocab.index(tok) for tok in seq] for seq in seqs]
|
||||
# Set the seed to make the embedding and LSTM deterministic (even
|
||||
# across ranks since DDP broadcasts parameters from rank 0)
|
||||
|
@ -426,7 +426,7 @@ def remove_torch(name):
|
||||
|
||||
def get_list_of_all_tests():
|
||||
all_tests = list(tested_overridable_outplace_ops.keys())
|
||||
return set([remove_torch(test) for test in all_tests])
|
||||
return {remove_torch(test) for test in all_tests}
|
||||
|
||||
|
||||
mytest = {
|
||||
@ -459,11 +459,11 @@ def get_jvp_coverage(subset=None):
|
||||
supports_forwardad_ops_dct = {name: op_to_opinfo[fn] for name, fn in ops_dct.items()
|
||||
if op_to_opinfo[fn][0].supports_forward_ad}
|
||||
|
||||
ops = set([remove_torch(test) for test in list(ops_dct.keys())])
|
||||
supports_autograd = set([remove_torch(test)
|
||||
for test in list(supports_autograd_ops_dct.keys())])
|
||||
supports_forward_ad = set([remove_torch(test)
|
||||
for test in list(supports_forwardad_ops_dct.keys())])
|
||||
ops = {remove_torch(test) for test in list(ops_dct.keys())}
|
||||
supports_autograd = {remove_torch(test)
|
||||
for test in list(supports_autograd_ops_dct.keys())}
|
||||
supports_forward_ad = {remove_torch(test)
|
||||
for test in list(supports_forwardad_ops_dct.keys())}
|
||||
assert supports_forward_ad.issubset(supports_autograd)
|
||||
assert supports_autograd.issubset(ops)
|
||||
|
||||
|
@ -169,12 +169,12 @@ class TestPythonKey(AOTTestCase):
|
||||
return torch.tanh(x).sum()
|
||||
|
||||
fx_f = make_fx(grad(f))(torch.randn(5))
|
||||
ops = set([i.target for i in fx_f.graph.nodes])
|
||||
ops = {i.target for i in fx_f.graph.nodes}
|
||||
|
||||
self.assertEqual(torch.ops.aten.tanh_backward in ops, True)
|
||||
|
||||
fx_f = make_fx(grad(f), decomposition_table)(torch.randn(5))
|
||||
ops = set([i.target for i in fx_f.graph.nodes])
|
||||
ops = {i.target for i in fx_f.graph.nodes}
|
||||
self.assertEqual(torch.ops.aten.tanh_backward in ops, False)
|
||||
|
||||
def test_nnc_jit(self, device):
|
||||
|
@ -18,7 +18,7 @@ class TestMinifier(TestCase):
|
||||
failing_f = make_fx(failing_f)(*inps)
|
||||
|
||||
def has_mul(fx_g, inps):
|
||||
return (torch.ops.aten.mul.Tensor in set([i.target for i in fx_g.graph.nodes]))
|
||||
return (torch.ops.aten.mul.Tensor in (i.target for i in fx_g.graph.nodes))
|
||||
|
||||
min_f, inps = minifier(failing_f, inps, has_mul)
|
||||
self.assertEqual(len(min_f.graph.nodes), 4)
|
||||
@ -74,7 +74,7 @@ class TestMinifier(TestCase):
|
||||
inps = [torch.randn(3), torch.randn(3)]
|
||||
|
||||
def has_add(fx_g, inps):
|
||||
return (torch.ops.aten.add.Tensor in set([i.target for i in fx_g.graph.nodes]))
|
||||
return (torch.ops.aten.add.Tensor in (i.target for i in fx_g.graph.nodes))
|
||||
|
||||
failing_f = make_fx(f)(*inps)
|
||||
min_f, inps = minifier(failing_f, inps, has_add)
|
||||
|
@ -114,7 +114,7 @@ def get_suggested_xfails(base, tests):
|
||||
tests = [test[len(base):] for test in tests if
|
||||
belongs_to_base(test, base)]
|
||||
|
||||
base_tests = set([remove_device_dtype(test) for test in tests])
|
||||
base_tests = {remove_device_dtype(test) for test in tests}
|
||||
tests = set(tests)
|
||||
for base in base_tests:
|
||||
cpu_variant = base + '_cpu_float32'
|
||||
|
@ -226,7 +226,7 @@ class TestList(JitTestCase):
|
||||
self.checkScript(foo2, ())
|
||||
|
||||
def foo3():
|
||||
return list(list("abc"))
|
||||
return list(list("abc")) # noqa: C414
|
||||
|
||||
self.checkScript(foo3, ())
|
||||
FileCheck().check_count("aten::list", 2, exactly=True).run(torch.jit.script(foo3).graph)
|
||||
|
@ -140,7 +140,7 @@ def calcOpsCoverage(ops):
|
||||
"_coverage": round(coverage, 2),
|
||||
"uncovered_ops": uncovered_ops_dict,
|
||||
"covered_ops": covered_ops_dict,
|
||||
"all_generated_ops": sorted(list(all_generated_ops)),
|
||||
"all_generated_ops": sorted(all_generated_ops),
|
||||
},
|
||||
f,
|
||||
)
|
||||
|
@ -40,7 +40,7 @@ def run_model_test(test_suite: _TestONNXRuntime, *args, **kwargs):
|
||||
if hasattr(test_suite, "check_dtype"):
|
||||
options.check_dtype = test_suite.check_dtype
|
||||
|
||||
names = set([f.name for f in dataclasses.fields(options)])
|
||||
names = {f.name for f in dataclasses.fields(options)}
|
||||
keywords_to_pop = []
|
||||
for k, v in kwargs.items():
|
||||
if k in names:
|
||||
|
@ -116,7 +116,7 @@ class TestDiGraph(PackageTestCase):
|
||||
|
||||
result = g.all_paths("1", "3")
|
||||
# to get rid of indeterminism
|
||||
actual = set([i.strip("\n") for i in result.split(";")[2:-1]])
|
||||
actual = {i.strip("\n") for i in result.split(";")[2:-1]}
|
||||
expected = {
|
||||
'"2" -> "3"',
|
||||
'"1" -> "7"',
|
||||
|
@ -365,10 +365,10 @@ class TestQuantizeEagerPTQStatic(QuantizationTestCase):
|
||||
# test one line API - out of place version
|
||||
base = AnnotatedSingleLayerLinearModel(qengine)
|
||||
base.qconfig = qconfig
|
||||
keys_before = set(list(base.state_dict().keys()))
|
||||
keys_before = set(base.state_dict().keys())
|
||||
model = quantize(base, test_only_eval_fn, [self.calib_data])
|
||||
checkQuantized(model)
|
||||
keys_after = set(list(base.state_dict().keys()))
|
||||
keys_after = set(base.state_dict().keys())
|
||||
self.assertEqual(keys_before, keys_after) # simple check that nothing changed
|
||||
|
||||
# in-place version
|
||||
@ -1107,10 +1107,10 @@ class TestQuantizeEagerPTQDynamic(QuantizationTestCase):
|
||||
|
||||
# test one line API - out of place version
|
||||
base = SingleLayerLinearDynamicModel()
|
||||
keys_before = set(list(base.state_dict().keys()))
|
||||
keys_before = set(base.state_dict().keys())
|
||||
model = quantize_dynamic(base, qconfig_dict)
|
||||
checkQuantized(model)
|
||||
keys_after = set(list(base.state_dict().keys()))
|
||||
keys_after = set(base.state_dict().keys())
|
||||
self.assertEqual(keys_before, keys_after) # simple check that nothing changed
|
||||
|
||||
# in-place version
|
||||
|
@ -900,7 +900,7 @@ class TestFxModelReportClass(QuantizationTestCase):
|
||||
model_report = ModelReport(model_prep, test_detector_set)
|
||||
|
||||
# make sure internal valid reports matches
|
||||
detector_name_set = set([detector.get_detector_name() for detector in test_detector_set])
|
||||
detector_name_set = {detector.get_detector_name() for detector in test_detector_set}
|
||||
self.assertEqual(model_report.get_desired_reports_names(), detector_name_set)
|
||||
|
||||
# now attempt with no valid reports, should raise error
|
||||
@ -1329,7 +1329,7 @@ class TestFxDetectInputWeightEqualization(QuantizationTestCase):
|
||||
mods_to_check = set([nn.Linear, nn.Conv2d])
|
||||
|
||||
# get the set of all nodes in the graph their fqns
|
||||
node_fqns = set([node.target for node in prepared_for_callibrate_model.graph.nodes])
|
||||
node_fqns = {node.target for node in prepared_for_callibrate_model.graph.nodes}
|
||||
|
||||
# there should be 4 node fqns that have the observer inserted
|
||||
correct_number_of_obs_inserted = 4
|
||||
|
@ -167,7 +167,7 @@ class TestNamedTupleAPI(TestCase):
|
||||
ret3 = meth(*op.input)
|
||||
check_namedtuple(ret3, op.names)
|
||||
|
||||
all_covered_operators = set([x for y in operators for x in y.operators])
|
||||
all_covered_operators = {x for y in operators for x in y.operators}
|
||||
|
||||
self.assertEqual(all_operators_with_namedtuple_return, all_covered_operators, textwrap.dedent('''
|
||||
The set of covered operators does not match the `all_operators_with_namedtuple_return` of
|
||||
|
@ -579,7 +579,7 @@ def forward(self, x_1):
|
||||
|
||||
|
||||
gm = make_fx(Emformer())(torch.randn(16, 1, 256))
|
||||
ops = set([n.target for n in gm.graph.nodes if n.op == 'call_function'])
|
||||
ops = {n.target for n in gm.graph.nodes if n.op == 'call_function'}
|
||||
self.assertEqual(len(ops), 2)
|
||||
|
||||
|
||||
|
@ -264,7 +264,7 @@ class TestSparse(TestSparseBase):
|
||||
else:
|
||||
value_map[idx_tup] = val.clone() if isinstance(val, torch.Tensor) else val
|
||||
|
||||
new_indices = sorted(list(value_map.keys()))
|
||||
new_indices = sorted(value_map.keys())
|
||||
_new_values = [value_map[idx] for idx in new_indices]
|
||||
if t._values().ndimension() < 2:
|
||||
new_values = t._values().new(_new_values)
|
||||
|
@ -130,13 +130,11 @@ FILENAME_ALLOWLIST = {
|
||||
}
|
||||
|
||||
# Include optimizer code for tracing
|
||||
FILENAME_ALLOWLIST |= set(
|
||||
[
|
||||
inspect.getfile(obj)
|
||||
for obj in torch.optim.__dict__.values()
|
||||
if inspect.isclass(obj)
|
||||
]
|
||||
)
|
||||
FILENAME_ALLOWLIST |= {
|
||||
inspect.getfile(obj)
|
||||
for obj in torch.optim.__dict__.values()
|
||||
if inspect.isclass(obj)
|
||||
}
|
||||
FILENAME_ALLOWLIST |= {torch.optim._functional.__file__}
|
||||
|
||||
if HAS_PRIMS_REFS:
|
||||
|
@ -760,7 +760,7 @@ def enum_repr(value):
|
||||
|
||||
|
||||
def dict_param_key_ids(value):
|
||||
return set([id(k) for k in value.keys() if isinstance(k, torch.nn.Parameter)])
|
||||
return {id(k) for k in value.keys() if isinstance(k, torch.nn.Parameter)}
|
||||
|
||||
|
||||
def dict_const_keys(value):
|
||||
@ -771,7 +771,7 @@ def dict_const_keys_repr(const_keys):
|
||||
if any(isinstance(k, enum.Enum) for k in const_keys):
|
||||
# To workaround repr(Enum) returning invalid global reference before python 3.11
|
||||
# by calling enum_repr and removing quotes to render enum in guard code.
|
||||
const_keys_str = f"{set([enum_repr(k) if isinstance(k, enum.Enum) else repr(k) for k in const_keys])}".replace(
|
||||
const_keys_str = f"{set(enum_repr(k) if isinstance(k, enum.Enum) else repr(k) for k in const_keys)}".replace(
|
||||
"'", ""
|
||||
)
|
||||
else:
|
||||
|
@ -304,17 +304,12 @@ class VariableBuilder:
|
||||
else:
|
||||
return key
|
||||
|
||||
result = dict(
|
||||
[
|
||||
(
|
||||
k,
|
||||
VariableBuilder(
|
||||
self.tx, GetItemSource(self.get_source(), index_source(k))
|
||||
)(value[k]).add_guards(guards),
|
||||
)
|
||||
for k in value.keys()
|
||||
]
|
||||
)
|
||||
result = {
|
||||
k: VariableBuilder(
|
||||
self.tx, GetItemSource(self.get_source(), index_source(k))
|
||||
)(value[k]).add_guards(guards)
|
||||
for k in value.keys()
|
||||
}
|
||||
|
||||
if istype(value, collections.defaultdict):
|
||||
result = DefaultDictVariable(
|
||||
|
@ -393,7 +393,7 @@ def min_cut_rematerialization_partition(
|
||||
for node in joint_module.graph.nodes
|
||||
if node.op == "call_function" and hasattr(node.target, "_overloadpacket")
|
||||
)
|
||||
ops_ignored = joint_module_ops - set([str(i) for i in recomputable_ops])
|
||||
ops_ignored = joint_module_ops - {str(i) for i in recomputable_ops}
|
||||
print("Ops banned from rematerialization: ", ops_ignored)
|
||||
print()
|
||||
|
||||
@ -522,8 +522,8 @@ def min_cut_rematerialization_partition(
|
||||
joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, num_fwd_outputs=num_fwd_outputs)
|
||||
if AOT_PARTITIONER_DEBUG:
|
||||
print("Theoretical Activations Stored: ", sum([_size_of(i) for i in saved_values]) / 1e9)
|
||||
fw_module_nodes = set([node.name for node in fw_module.graph.nodes if node.op == 'call_function'])
|
||||
bw_module_nodes = set([node.name for node in bw_module.graph.nodes if node.op == 'call_function'])
|
||||
fw_module_nodes = {node.name for node in fw_module.graph.nodes if node.op == 'call_function'}
|
||||
bw_module_nodes = {node.name for node in bw_module.graph.nodes if node.op == 'call_function'}
|
||||
remat_nodes = fw_module_nodes & bw_module_nodes
|
||||
|
||||
counts = defaultdict(int)
|
||||
|
@ -535,9 +535,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
writes = set(dep.name for dep in node.read_writes.writes)
|
||||
|
||||
def is_materialized(buf):
|
||||
buf_uses = set(
|
||||
[user.node for user in scheduler.name_to_node[buf].users]
|
||||
)
|
||||
buf_uses = {user.node for user in scheduler.name_to_node[buf].users}
|
||||
return len(buf_uses - set(node.snodes)) > 0
|
||||
|
||||
if isinstance(node, FusedSchedulerNode):
|
||||
|
@ -344,7 +344,9 @@ def fresh_inductor_cache(cache_entries=None):
|
||||
|
||||
def argsort(seq):
|
||||
# preserve original order for equal strides
|
||||
return list(reversed(sorted(range(len(seq)), key=seq.__getitem__, reverse=True)))
|
||||
getter = seq.__getitem__
|
||||
a_r = range(len(seq))
|
||||
return list(reversed(sorted(a_r, key=getter, reverse=True))) # noqa: C413
|
||||
|
||||
|
||||
@functools.lru_cache(8)
|
||||
|
@ -120,7 +120,7 @@ class ModelReport:
|
||||
|
||||
# keep the reports private so they can't be modified
|
||||
self._desired_report_detectors = desired_report_detectors
|
||||
self._desired_detector_names = set([detector.get_detector_name() for detector in desired_report_detectors])
|
||||
self._desired_detector_names = {detector.get_detector_name() for detector in desired_report_detectors}
|
||||
|
||||
# keep a mapping of desired reports to observers of interest
|
||||
# this is to get the readings, and to remove them, can create a large set
|
||||
|
@ -1598,7 +1598,7 @@ def _all_gather_optim_state(
|
||||
gathered_state: Dict[str, Any] = {}
|
||||
|
||||
all_tensor_states = sorted(
|
||||
set([n for state in object_list for n in state.tensors.keys()])
|
||||
{n for state in object_list for n in state.tensors.keys()}
|
||||
)
|
||||
empty_ranks: Set[int] = set()
|
||||
for name in all_tensor_states:
|
||||
|
@ -264,7 +264,7 @@ class Tracer(TracerBase):
|
||||
for name, value in chain(*[m.__dict__.items() for m in autowrap_modules])
|
||||
if not name.startswith("_") and callable(value)
|
||||
}
|
||||
self._autowrap_function_ids.update(set([id(f) for f in autowrap_functions]))
|
||||
self._autowrap_function_ids.update({id(f) for f in autowrap_functions})
|
||||
|
||||
# Python modules to apply autowrap to at the start, in addition to
|
||||
# modules we see while tracing
|
||||
|
@ -3611,8 +3611,8 @@ def random_sparse_pd_matrix(matrix_size, density=0.01, **kwargs):
|
||||
torch = kwargs.get('torch', globals()['torch'])
|
||||
dtype = kwargs.get('dtype', torch.double)
|
||||
device = kwargs.get('device', 'cpu')
|
||||
data = dict([((i, i), float(i + 1) / matrix_size)
|
||||
for i in range(matrix_size)])
|
||||
data = {(i, i): float(i + 1) / matrix_size
|
||||
for i in range(matrix_size)}
|
||||
|
||||
|
||||
def multiply(data, N, i, j, cs, sn, left=True):
|
||||
|
@ -377,29 +377,25 @@ def gen_dispatchkey_nativefunc_headers(
|
||||
# Convert to a set first to remove duplicate kernel names.
|
||||
# Backends are allowed to repeat kernel names; only generate the declaration once!
|
||||
# Sort for deterministic output.
|
||||
backend_declarations = list(
|
||||
sorted(
|
||||
set(
|
||||
concatMap(
|
||||
lambda f: dest.compute_native_function_declaration(
|
||||
f, backend_indices[backend_dispatch_key]
|
||||
),
|
||||
grouped_native_functions,
|
||||
)
|
||||
backend_declarations = sorted(
|
||||
set(
|
||||
concatMap(
|
||||
lambda f: dest.compute_native_function_declaration(
|
||||
f, backend_indices[backend_dispatch_key]
|
||||
),
|
||||
grouped_native_functions,
|
||||
)
|
||||
)
|
||||
)
|
||||
autograd_declarations = list(
|
||||
sorted(
|
||||
set(
|
||||
concatMap(
|
||||
lambda f: []
|
||||
if autograd_dispatch_key is None
|
||||
else dest.compute_native_function_declaration(
|
||||
f, backend_indices[autograd_dispatch_key]
|
||||
),
|
||||
grouped_native_functions,
|
||||
)
|
||||
autograd_declarations = sorted(
|
||||
set(
|
||||
concatMap(
|
||||
lambda f: []
|
||||
if autograd_dispatch_key is None
|
||||
else dest.compute_native_function_declaration(
|
||||
f, backend_indices[autograd_dispatch_key]
|
||||
),
|
||||
grouped_native_functions,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
@ -1058,7 +1058,7 @@ class NativeFunctionsGroup:
|
||||
for f in self.functions():
|
||||
expected_generated_fns.update(str(op) for op in f.autogen)
|
||||
expected_generated_fns_str = ", ".join(
|
||||
str(x) for x in sorted(list(expected_generated_fns))
|
||||
str(x) for x in sorted(expected_generated_fns)
|
||||
)
|
||||
if len(expected_generated_fns) == 0 and len(generated_fns) > 0:
|
||||
raise RuntimeError(
|
||||
|
@ -231,7 +231,7 @@ class SelectiveBuilder:
|
||||
ret["debug_info"] = sorted(self._debug_info)
|
||||
|
||||
ret["kernel_metadata"] = {
|
||||
k: sorted(list(v)) for (k, v) in self.kernel_metadata.items()
|
||||
k: sorted(v) for (k, v) in self.kernel_metadata.items()
|
||||
}
|
||||
|
||||
ret["custom_classes"] = sorted(self.custom_classes)
|
||||
|
Reference in New Issue
Block a user