Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
[Ez][BE]: Enable new stable ruff rules (#129825)
Applies a batch of new ruff lint rules that are now stable. Some of these improve efficiency or readability. Since I already did passes on the codebase for these rules while they were in preview, there should be relatively few changes; this is mostly future hardening.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129825
Approved by: https://github.com/XuehaiPan, https://github.com/jansel, https://github.com/malfet
commit 6c2a8b6b38
parent 2926655761
committed by PyTorch MergeBot

@@ -226,17 +226,17 @@ def main():
     print("-----------------------------------")
     print("PyTorch distributed benchmark suite")
     print("-----------------------------------")
-    print("")
+    print()
     print(f"* PyTorch version: {torch.__version__}")
     print(f"* CUDA version: {torch.version.cuda}")
     print(f"* Distributed backend: {args.distributed_backend}")
     print(f"* Maximum bucket size: {args.bucket_size}MB")
-    print("")
+    print()
     print("--- nvidia-smi topo -m ---")
-    print("")
+    print()
     print(output[0])
     print("--------------------------")
-    print("")
+    print()

     torch.cuda.set_device(dist.get_rank() % 8)
     device = torch.device("cuda:%d" % (dist.get_rank() % 8))
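Note: the `print("")` → `print()` rewrites throughout this diff correspond to ruff's FURB105 (unnecessary empty string passed to `print`); both forms emit a single newline, since `print` already appends `end="\n"`. A minimal sketch of the equivalence:

    # Both calls write exactly one newline to stdout; the empty-string
    # argument adds nothing, which is what FURB105 flags.
    print("")  # flagged
    print()    # preferred
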
@@ -32,7 +32,7 @@ def main():
         va = str(ja.get(key, "-"))
         vb = str(jb.get(key, "-"))
         print(f"{key + ':':20s} {va:>20s} vs {vb:>20s}")
-    print("")
+    print()

     ba = ja["benchmark_results"]
     bb = jb["benchmark_results"]
@@ -48,13 +48,13 @@ def main():
         print(f"Benchmark: {name}")

         # Print header
-        print("")
+        print()
         print(f"{'':>10s}", end="")  # noqa: E999
         for _ in [75, 95]:
             print(
                 f"{'sec/iter':>16s}{'ex/sec':>10s}{'diff':>10s}", end=""
             )  # noqa: E999
-        print("")
+        print()

         # Print measurements
         for i, (xa, xb) in enumerate(zip(ra["result"], rb["result"])):
@@ -78,8 +78,8 @@ def main():
                 f" p{p:02d}: {vb:8.3f}s {int(batch_size / vb):7d}/s {delta:+8.1f}%",
                 end="",
             )  # noqa: E999
-        print("")
-    print("")
+        print()
+    print()


 if __name__ == "__main__":
@@ -318,8 +318,8 @@ class TimmRunner(BenchmarkRunner):
             if index < start or index >= end:
                 continue
             if (
-                not re.search("|".join(args.filter), model_name, re.I)
-                or re.search("|".join(args.exclude), model_name, re.I)
+                not re.search("|".join(args.filter), model_name, re.IGNORECASE)
+                or re.search("|".join(args.exclude), model_name, re.IGNORECASE)
                 or model_name in args.exclude_exact
                 or model_name in self.skip_models
             ):
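The `re.I` → `re.IGNORECASE` rewrites in this and the following hunks are purely cosmetic: the short names are aliases for the very same flag objects, so matching behavior is unchanged (this is the pattern ruff's FURB167, regex-flag-alias, targets). A quick check:

    import re

    # The short flags are the same enum members as the long names,
    # so swapping them cannot change behavior.
    assert re.I is re.IGNORECASE
    assert re.M is re.MULTILINE
    assert re.search("resnet", "ResNet50", re.IGNORECASE) is not None
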
@@ -399,8 +399,8 @@ class TorchBenchmarkRunner(BenchmarkRunner):

             model_name = os.path.basename(model_path)
             if (
-                not re.search("|".join(args.filter), model_name, re.I)
-                or re.search("|".join(args.exclude), model_name, re.I)
+                not re.search("|".join(args.filter), model_name, re.IGNORECASE)
+                or re.search("|".join(args.exclude), model_name, re.IGNORECASE)
                 or model_name in args.exclude_exact
                 or model_name in self.skip_models
             ):
@@ -79,7 +79,7 @@ def test_rnns(

     if verbose:
         print(experim.forward.graph_for(*experim.inputs))
-        print("")
+        print()


 def test_vl_py(**test_args):
@@ -141,7 +141,7 @@ def test_vl_py(**test_args):

     if test_args["verbose"]:
         print(experim.forward.graph_for(*experim.inputs))
-        print("")
+        print()


 if __name__ == "__main__":
@@ -97,6 +97,7 @@ select = [
     "SIM1",
     "W",
     # Not included in flake8
+    "FURB",
     "LOG",
     "NPY",
     "PERF",
@@ -113,10 +114,13 @@ select = [
     "PLR0133", # constant comparison
     "PLR0206", # property with params
     "PLR1722", # use sys exit
+    "PLR1736", # unnecessary list index
     "PLW0129", # assert on string literal
+    "PLW0133", # useless exception statement
     "PLW0406", # import self
     "PLW0711", # binary op exception
     "PLW1509", # preexec_fn not safe with threads
     "PLW2101", # useless lock statement
+    "PLW3301", # nested min max
     "PT006", # TODO: enable more PT rules
     "PT022",
@@ -133,6 +137,8 @@ select = [
     "RUF016", # type error non-integer index
     "RUF017",
+    "RUF018", # no assignment in assert
     "RUF024", # from keys mutable
+    "RUF026", # default factory kwarg
     "TCH",
     "TRY002", # ban vanilla raise (todo fix NOQAs)
     "TRY302",
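These `select` entries extend the `[tool.ruff.lint]` rule list in the repository's ruff configuration. As a hedged sketch of exercising one of the newly enabled rules, assuming `ruff` is installed and `example.py` is a placeholder file (neither is part of this PR):

    import subprocess

    # Check a single file against FURB105 only; --fix applies ruff's
    # autofix (print("") -> print()) where it is safe to do so.
    result = subprocess.run(
        ["ruff", "check", "--select", "FURB105", "--fix", "example.py"],
        capture_output=True,
        text=True,
    )
    print(result.stdout)
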
@@ -10,7 +10,7 @@ def check_error(desc, fn, *required_substrings):
     print(desc)
     print("-" * 80)
     print(error_message)
-    print("")
+    print()
     for sub in required_substrings:
         assert sub in error_message
     return
@@ -51,7 +51,7 @@ class UpscaleBlock(nn.Module):

 class SRResNet(nn.Module):
     def __init__(self, rescale_factor, n_filters, n_blocks):
         super().__init__()
-        self.rescale_levels = int(math.log(rescale_factor, 2))
+        self.rescale_levels = int(math.log(rescale_factor, 2))  # noqa: FURB163
         self.n_filters = n_filters
         self.n_blocks = n_blocks
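The added `# noqa: FURB163` suppresses refurb's redundant-log-base rule, which would otherwise rewrite `math.log(rescale_factor, 2)` as `math.log2(rescale_factor)`. The two agree mathematically, but `math.log(x, 2)` is computed as `log(x) / log(2)` and can differ from the dedicated `math.log2` routine in the last floating-point digit, which is presumably why the original spelling was kept here. An illustration:

    import math

    x = 1 << 20
    print(math.log(x, 2))  # log(x)/log(2); may print 20.000000000000004
    print(math.log2(x))    # dedicated routine; prints 20.0 on most platforms
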
@@ -80,7 +80,7 @@ def evaluate(model, criterion, data_loader):

             acc1, acc5 = accuracy(output, target, topk=(1, 5))
             top1.update(acc1[0], image.size(0))
             top5.update(acc5[0], image.size(0))
-        print('')
+        print()

         return top1, top5
@@ -1903,8 +1903,8 @@ class TestQuantizedOps(TestCase):
         X = np.array(X)
         scale = 1
         H, W = X.shape[-2:]
-        output_size_h = output_size_h if (output_size_h <= H) else H
-        output_size_w = output_size_w if (output_size_w <= W) else W
+        output_size_h = min(output_size_h, H)
+        output_size_w = min(output_size_w, W)
         if output_size_h == output_size_w:
             output_size = output_size_h
         else:
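The ternary-to-`min` rewrites in this and the next two hunks match refurb's FURB136 (if-expression-min-max): `a if a <= b else b` is exactly `min(a, b)`, just harder to read. A quick equivalence check:

    # The conditional expression and min() agree for every ordering,
    # including ties, so the rewrite is behavior-preserving.
    for a, b in [(3, 7), (7, 3), (5, 5)]:
        assert (a if (a <= b) else b) == min(a, b)
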
@@ -1977,9 +1977,9 @@ class TestQuantizedOps(TestCase):
             dim_to_check.append(3)

         D, H, W = X.shape[-3:]
-        output_size_d = output_size_d if (output_size_d <= D) else D
-        output_size_h = output_size_h if (output_size_h <= H) else H
-        output_size_w = output_size_w if (output_size_w <= W) else W
+        output_size_d = min(output_size_d, D)
+        output_size_h = min(output_size_h, H)
+        output_size_w = min(output_size_w, W)

         X = torch.from_numpy(X)
         qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
@@ -2049,9 +2049,9 @@ class TestQuantizedOps(TestCase):
         X = np.array(X)
         scale = 1
         D, H, W = X.shape[-3:]
-        output_size_d = output_size_d if (output_size_d <= D) else D
-        output_size_h = output_size_h if (output_size_h <= H) else H
-        output_size_w = output_size_w if (output_size_w <= W) else W
+        output_size_d = min(output_size_d, D)
+        output_size_h = min(output_size_h, H)
+        output_size_w = min(output_size_w, W)
         if output_size_d == output_size_h == output_size_w:
             output_size = output_size_h
         else:
@@ -4661,7 +4661,7 @@ Done""",
         self.assertEqual(avg.device_time_total, 0)

     def test_profiler_shapes(self):
-        print("")
+        print()
         layer1 = torch.nn.Linear(20, 30)
         layer2 = torch.nn.Linear(30, 40)
         input = torch.randn(128, 20)
@@ -4683,7 +4683,7 @@ Done""",
         self.assertEqual(len(found_indices), len(linear_expected_shapes))

     def test_profiler_aggregation_lstm(self):
-        print("")
+        print()
         rnn = torch.nn.LSTM(10, 20, 2)
         total_time_s = 0
         with profile(record_shapes=True, use_kineto=kineto_available()) as prof:
@@ -1632,7 +1632,7 @@ def choose_saved_values_set(
         for i, txt in enumerate(x_values):
             plt.annotate(
                 f"{txt:.2f}",
-                (x_values[i], y_values[i]),
+                (txt, y_values[i]),
                 textcoords="offset points",
                 xytext=(0, 10),
                 ha="center",
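This `plt.annotate` change is pylint's PLR1736 (unnecessary list index lookup), newly selected above: inside `enumerate(x_values)`, `txt` already holds `x_values[i]`, so indexing the list again is redundant. The pattern in isolation:

    x_values = [0.1, 0.5, 0.9]
    for i, txt in enumerate(x_values):
        # txt is x_values[i]; PLR1736 removes the redundant re-indexing.
        assert txt == x_values[i]
        point = (txt, 1.0)  # equivalent to (x_values[i], 1.0)
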
@@ -430,8 +430,8 @@ class ExprPrinter(Printer):

         if (
             isinstance(string, CSEVariable)
-            or re.match(r"^[a-z0-9_.]+$", string, re.I)
-            or re.match(r"^\([^)]*\)$", string, re.I)
+            or re.match(r"^[a-z0-9_.]+$", string, re.IGNORECASE)
+            or re.match(r"^\([^)]*\)$", string, re.IGNORECASE)
             or string == ""
         ):
             return string
@@ -1966,11 +1966,11 @@ class suppress_warnings:
             self._clear_registries()

             self._tmp_suppressions.append(
-                (category, message, re.compile(message, re.I), module, record)
+                (category, message, re.compile(message, re.IGNORECASE), module, record)
             )
         else:
             self._suppressions.append(
-                (category, message, re.compile(message, re.I), module, record)
+                (category, message, re.compile(message, re.IGNORECASE), module, record)
             )

         return record
@@ -2318,7 +2318,8 @@ def _parse_size(size_str):
     }

     size_re = re.compile(
-        r"^\s*(\d+|\d+\.\d+)\s*({})\s*$".format("|".join(suffixes.keys())), re.I
+        r"^\s*(\d+|\d+\.\d+)\s*({})\s*$".format("|".join(suffixes.keys())),
+        re.IGNORECASE,
     )

     m = size_re.match(size_str.lower())
@@ -297,12 +297,12 @@ if __name__ == "__main__":
     print("-------------------------------------------")
     print(" Info ")
     print("-------------------------------------------")
-    print("")
+    print()
     print(f"* PyTorch version: {torch.__version__}")
     print(f"* CUDA version: {torch.version.cuda}")
-    print("")
+    print()
     print("------------ nvidia-smi topo -m -----------")
-    print("")
+    print()
     print(output[0])
     print("-------------------------------------------")
     print("PyTorch Distributed Benchmark (DDP and RPC)")
@@ -5018,5 +5018,5 @@ def munge_exc(e, *, suppress_suffix=True, suppress_prefix=True, file=None, skip=
         s = re.sub(r"\n*You can suppress this exception.+", "", s, flags=re.DOTALL)
     if suppress_prefix:
         s = re.sub(r"Cannot export model.+\n\n", "", s)
-    s = re.sub(r" +$", "", s, flags=re.M)
+    s = re.sub(r" +$", "", s, flags=re.MULTILINE)
     return s