mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Ez][BE]: Enable new stable ruff rules (#129825)
Applies a bunch of new ruff lint rules that are now stable. Some of these improve efficiency or readability. Since I already did passes on the codebase for these when they were in preview, there should be relatively few changes to the codebase. This is just more for future hardening of it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/129825 Approved by: https://github.com/XuehaiPan, https://github.com/jansel, https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
2926655761
commit
6c2a8b6b38
@ -226,17 +226,17 @@ def main():
|
|||||||
print("-----------------------------------")
|
print("-----------------------------------")
|
||||||
print("PyTorch distributed benchmark suite")
|
print("PyTorch distributed benchmark suite")
|
||||||
print("-----------------------------------")
|
print("-----------------------------------")
|
||||||
print("")
|
print()
|
||||||
print(f"* PyTorch version: {torch.__version__}")
|
print(f"* PyTorch version: {torch.__version__}")
|
||||||
print(f"* CUDA version: {torch.version.cuda}")
|
print(f"* CUDA version: {torch.version.cuda}")
|
||||||
print(f"* Distributed backend: {args.distributed_backend}")
|
print(f"* Distributed backend: {args.distributed_backend}")
|
||||||
print(f"* Maximum bucket size: {args.bucket_size}MB")
|
print(f"* Maximum bucket size: {args.bucket_size}MB")
|
||||||
print("")
|
print()
|
||||||
print("--- nvidia-smi topo -m ---")
|
print("--- nvidia-smi topo -m ---")
|
||||||
print("")
|
print()
|
||||||
print(output[0])
|
print(output[0])
|
||||||
print("--------------------------")
|
print("--------------------------")
|
||||||
print("")
|
print()
|
||||||
|
|
||||||
torch.cuda.set_device(dist.get_rank() % 8)
|
torch.cuda.set_device(dist.get_rank() % 8)
|
||||||
device = torch.device("cuda:%d" % (dist.get_rank() % 8))
|
device = torch.device("cuda:%d" % (dist.get_rank() % 8))
|
||||||
|
@ -32,7 +32,7 @@ def main():
|
|||||||
va = str(ja.get(key, "-"))
|
va = str(ja.get(key, "-"))
|
||||||
vb = str(jb.get(key, "-"))
|
vb = str(jb.get(key, "-"))
|
||||||
print(f"{key + ':':20s} {va:>20s} vs {vb:>20s}")
|
print(f"{key + ':':20s} {va:>20s} vs {vb:>20s}")
|
||||||
print("")
|
print()
|
||||||
|
|
||||||
ba = ja["benchmark_results"]
|
ba = ja["benchmark_results"]
|
||||||
bb = jb["benchmark_results"]
|
bb = jb["benchmark_results"]
|
||||||
@ -48,13 +48,13 @@ def main():
|
|||||||
print(f"Benchmark: {name}")
|
print(f"Benchmark: {name}")
|
||||||
|
|
||||||
# Print header
|
# Print header
|
||||||
print("")
|
print()
|
||||||
print(f"{'':>10s}", end="") # noqa: E999
|
print(f"{'':>10s}", end="") # noqa: E999
|
||||||
for _ in [75, 95]:
|
for _ in [75, 95]:
|
||||||
print(
|
print(
|
||||||
f"{'sec/iter':>16s}{'ex/sec':>10s}{'diff':>10s}", end=""
|
f"{'sec/iter':>16s}{'ex/sec':>10s}{'diff':>10s}", end=""
|
||||||
) # noqa: E999
|
) # noqa: E999
|
||||||
print("")
|
print()
|
||||||
|
|
||||||
# Print measurements
|
# Print measurements
|
||||||
for i, (xa, xb) in enumerate(zip(ra["result"], rb["result"])):
|
for i, (xa, xb) in enumerate(zip(ra["result"], rb["result"])):
|
||||||
@ -78,8 +78,8 @@ def main():
|
|||||||
f" p{p:02d}: {vb:8.3f}s {int(batch_size / vb):7d}/s {delta:+8.1f}%",
|
f" p{p:02d}: {vb:8.3f}s {int(batch_size / vb):7d}/s {delta:+8.1f}%",
|
||||||
end="",
|
end="",
|
||||||
) # noqa: E999
|
) # noqa: E999
|
||||||
print("")
|
print()
|
||||||
print("")
|
print()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -318,8 +318,8 @@ class TimmRunner(BenchmarkRunner):
|
|||||||
if index < start or index >= end:
|
if index < start or index >= end:
|
||||||
continue
|
continue
|
||||||
if (
|
if (
|
||||||
not re.search("|".join(args.filter), model_name, re.I)
|
not re.search("|".join(args.filter), model_name, re.IGNORECASE)
|
||||||
or re.search("|".join(args.exclude), model_name, re.I)
|
or re.search("|".join(args.exclude), model_name, re.IGNORECASE)
|
||||||
or model_name in args.exclude_exact
|
or model_name in args.exclude_exact
|
||||||
or model_name in self.skip_models
|
or model_name in self.skip_models
|
||||||
):
|
):
|
||||||
|
@ -399,8 +399,8 @@ class TorchBenchmarkRunner(BenchmarkRunner):
|
|||||||
|
|
||||||
model_name = os.path.basename(model_path)
|
model_name = os.path.basename(model_path)
|
||||||
if (
|
if (
|
||||||
not re.search("|".join(args.filter), model_name, re.I)
|
not re.search("|".join(args.filter), model_name, re.IGNORECASE)
|
||||||
or re.search("|".join(args.exclude), model_name, re.I)
|
or re.search("|".join(args.exclude), model_name, re.IGNORECASE)
|
||||||
or model_name in args.exclude_exact
|
or model_name in args.exclude_exact
|
||||||
or model_name in self.skip_models
|
or model_name in self.skip_models
|
||||||
):
|
):
|
||||||
|
@ -79,7 +79,7 @@ def test_rnns(
|
|||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print(experim.forward.graph_for(*experim.inputs))
|
print(experim.forward.graph_for(*experim.inputs))
|
||||||
print("")
|
print()
|
||||||
|
|
||||||
|
|
||||||
def test_vl_py(**test_args):
|
def test_vl_py(**test_args):
|
||||||
@ -141,7 +141,7 @@ def test_vl_py(**test_args):
|
|||||||
|
|
||||||
if test_args["verbose"]:
|
if test_args["verbose"]:
|
||||||
print(experim.forward.graph_for(*experim.inputs))
|
print(experim.forward.graph_for(*experim.inputs))
|
||||||
print("")
|
print()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -97,6 +97,7 @@ select = [
|
|||||||
"SIM1",
|
"SIM1",
|
||||||
"W",
|
"W",
|
||||||
# Not included in flake8
|
# Not included in flake8
|
||||||
|
"FURB",
|
||||||
"LOG",
|
"LOG",
|
||||||
"NPY",
|
"NPY",
|
||||||
"PERF",
|
"PERF",
|
||||||
@ -113,10 +114,13 @@ select = [
|
|||||||
"PLR0133", # constant comparison
|
"PLR0133", # constant comparison
|
||||||
"PLR0206", # property with params
|
"PLR0206", # property with params
|
||||||
"PLR1722", # use sys exit
|
"PLR1722", # use sys exit
|
||||||
|
"PLR1736", # unnecessary list index
|
||||||
"PLW0129", # assert on string literal
|
"PLW0129", # assert on string literal
|
||||||
|
"PLW0133", # useless exception statement
|
||||||
"PLW0406", # import self
|
"PLW0406", # import self
|
||||||
"PLW0711", # binary op exception
|
"PLW0711", # binary op exception
|
||||||
"PLW1509", # preexec_fn not safe with threads
|
"PLW1509", # preexec_fn not safe with threads
|
||||||
|
"PLW2101", # useless lock statement
|
||||||
"PLW3301", # nested min max
|
"PLW3301", # nested min max
|
||||||
"PT006", # TODO: enable more PT rules
|
"PT006", # TODO: enable more PT rules
|
||||||
"PT022",
|
"PT022",
|
||||||
@ -133,6 +137,8 @@ select = [
|
|||||||
"RUF016", # type error non-integer index
|
"RUF016", # type error non-integer index
|
||||||
"RUF017",
|
"RUF017",
|
||||||
"RUF018", # no assignment in assert
|
"RUF018", # no assignment in assert
|
||||||
|
"RUF024", # from keys mutable
|
||||||
|
"RUF026", # default factory kwarg
|
||||||
"TCH",
|
"TCH",
|
||||||
"TRY002", # ban vanilla raise (todo fix NOQAs)
|
"TRY002", # ban vanilla raise (todo fix NOQAs)
|
||||||
"TRY302",
|
"TRY302",
|
||||||
|
@ -10,7 +10,7 @@ def check_error(desc, fn, *required_substrings):
|
|||||||
print(desc)
|
print(desc)
|
||||||
print("-" * 80)
|
print("-" * 80)
|
||||||
print(error_message)
|
print(error_message)
|
||||||
print("")
|
print()
|
||||||
for sub in required_substrings:
|
for sub in required_substrings:
|
||||||
assert sub in error_message
|
assert sub in error_message
|
||||||
return
|
return
|
||||||
|
@ -51,7 +51,7 @@ class UpscaleBlock(nn.Module):
|
|||||||
class SRResNet(nn.Module):
|
class SRResNet(nn.Module):
|
||||||
def __init__(self, rescale_factor, n_filters, n_blocks):
|
def __init__(self, rescale_factor, n_filters, n_blocks):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rescale_levels = int(math.log(rescale_factor, 2))
|
self.rescale_levels = int(math.log(rescale_factor, 2)) # noqa: FURB163
|
||||||
self.n_filters = n_filters
|
self.n_filters = n_filters
|
||||||
self.n_blocks = n_blocks
|
self.n_blocks = n_blocks
|
||||||
|
|
||||||
|
@ -80,7 +80,7 @@ def evaluate(model, criterion, data_loader):
|
|||||||
acc1, acc5 = accuracy(output, target, topk=(1, 5))
|
acc1, acc5 = accuracy(output, target, topk=(1, 5))
|
||||||
top1.update(acc1[0], image.size(0))
|
top1.update(acc1[0], image.size(0))
|
||||||
top5.update(acc5[0], image.size(0))
|
top5.update(acc5[0], image.size(0))
|
||||||
print('')
|
print()
|
||||||
|
|
||||||
return top1, top5
|
return top1, top5
|
||||||
|
|
||||||
|
@ -1903,8 +1903,8 @@ class TestQuantizedOps(TestCase):
|
|||||||
X = np.array(X)
|
X = np.array(X)
|
||||||
scale = 1
|
scale = 1
|
||||||
H, W = X.shape[-2:]
|
H, W = X.shape[-2:]
|
||||||
output_size_h = output_size_h if (output_size_h <= H) else H
|
output_size_h = min(output_size_h, H)
|
||||||
output_size_w = output_size_w if (output_size_w <= W) else W
|
output_size_w = min(output_size_w, W)
|
||||||
if output_size_h == output_size_w:
|
if output_size_h == output_size_w:
|
||||||
output_size = output_size_h
|
output_size = output_size_h
|
||||||
else:
|
else:
|
||||||
@ -1977,9 +1977,9 @@ class TestQuantizedOps(TestCase):
|
|||||||
dim_to_check.append(3)
|
dim_to_check.append(3)
|
||||||
|
|
||||||
D, H, W = X.shape[-3:]
|
D, H, W = X.shape[-3:]
|
||||||
output_size_d = output_size_d if (output_size_d <= D) else D
|
output_size_d = min(output_size_d, D)
|
||||||
output_size_h = output_size_h if (output_size_h <= H) else H
|
output_size_h = min(output_size_h, H)
|
||||||
output_size_w = output_size_w if (output_size_w <= W) else W
|
output_size_w = min(output_size_w, W)
|
||||||
|
|
||||||
X = torch.from_numpy(X)
|
X = torch.from_numpy(X)
|
||||||
qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
|
qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
|
||||||
@ -2049,9 +2049,9 @@ class TestQuantizedOps(TestCase):
|
|||||||
X = np.array(X)
|
X = np.array(X)
|
||||||
scale = 1
|
scale = 1
|
||||||
D, H, W = X.shape[-3:]
|
D, H, W = X.shape[-3:]
|
||||||
output_size_d = output_size_d if (output_size_d <= D) else D
|
output_size_d = min(output_size_d, D)
|
||||||
output_size_h = output_size_h if (output_size_h <= H) else H
|
output_size_h = min(output_size_h, H)
|
||||||
output_size_w = output_size_w if (output_size_w <= W) else W
|
output_size_w = min(output_size_w, W)
|
||||||
if output_size_d == output_size_h == output_size_w:
|
if output_size_d == output_size_h == output_size_w:
|
||||||
output_size = output_size_h
|
output_size = output_size_h
|
||||||
else:
|
else:
|
||||||
|
@ -4661,7 +4661,7 @@ Done""",
|
|||||||
self.assertEqual(avg.device_time_total, 0)
|
self.assertEqual(avg.device_time_total, 0)
|
||||||
|
|
||||||
def test_profiler_shapes(self):
|
def test_profiler_shapes(self):
|
||||||
print("")
|
print()
|
||||||
layer1 = torch.nn.Linear(20, 30)
|
layer1 = torch.nn.Linear(20, 30)
|
||||||
layer2 = torch.nn.Linear(30, 40)
|
layer2 = torch.nn.Linear(30, 40)
|
||||||
input = torch.randn(128, 20)
|
input = torch.randn(128, 20)
|
||||||
@ -4683,7 +4683,7 @@ Done""",
|
|||||||
self.assertEqual(len(found_indices), len(linear_expected_shapes))
|
self.assertEqual(len(found_indices), len(linear_expected_shapes))
|
||||||
|
|
||||||
def test_profiler_aggregation_lstm(self):
|
def test_profiler_aggregation_lstm(self):
|
||||||
print("")
|
print()
|
||||||
rnn = torch.nn.LSTM(10, 20, 2)
|
rnn = torch.nn.LSTM(10, 20, 2)
|
||||||
total_time_s = 0
|
total_time_s = 0
|
||||||
with profile(record_shapes=True, use_kineto=kineto_available()) as prof:
|
with profile(record_shapes=True, use_kineto=kineto_available()) as prof:
|
||||||
|
@ -1632,7 +1632,7 @@ def choose_saved_values_set(
|
|||||||
for i, txt in enumerate(x_values):
|
for i, txt in enumerate(x_values):
|
||||||
plt.annotate(
|
plt.annotate(
|
||||||
f"{txt:.2f}",
|
f"{txt:.2f}",
|
||||||
(x_values[i], y_values[i]),
|
(txt, y_values[i]),
|
||||||
textcoords="offset points",
|
textcoords="offset points",
|
||||||
xytext=(0, 10),
|
xytext=(0, 10),
|
||||||
ha="center",
|
ha="center",
|
||||||
|
@ -430,8 +430,8 @@ class ExprPrinter(Printer):
|
|||||||
|
|
||||||
if (
|
if (
|
||||||
isinstance(string, CSEVariable)
|
isinstance(string, CSEVariable)
|
||||||
or re.match(r"^[a-z0-9_.]+$", string, re.I)
|
or re.match(r"^[a-z0-9_.]+$", string, re.IGNORECASE)
|
||||||
or re.match(r"^\([^)]*\)$", string, re.I)
|
or re.match(r"^\([^)]*\)$", string, re.IGNORECASE)
|
||||||
or string == ""
|
or string == ""
|
||||||
):
|
):
|
||||||
return string
|
return string
|
||||||
|
@ -1966,11 +1966,11 @@ class suppress_warnings:
|
|||||||
self._clear_registries()
|
self._clear_registries()
|
||||||
|
|
||||||
self._tmp_suppressions.append(
|
self._tmp_suppressions.append(
|
||||||
(category, message, re.compile(message, re.I), module, record)
|
(category, message, re.compile(message, re.IGNORECASE), module, record)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._suppressions.append(
|
self._suppressions.append(
|
||||||
(category, message, re.compile(message, re.I), module, record)
|
(category, message, re.compile(message, re.IGNORECASE), module, record)
|
||||||
)
|
)
|
||||||
|
|
||||||
return record
|
return record
|
||||||
@ -2318,7 +2318,8 @@ def _parse_size(size_str):
|
|||||||
}
|
}
|
||||||
|
|
||||||
size_re = re.compile(
|
size_re = re.compile(
|
||||||
r"^\s*(\d+|\d+\.\d+)\s*({})\s*$".format("|".join(suffixes.keys())), re.I
|
r"^\s*(\d+|\d+\.\d+)\s*({})\s*$".format("|".join(suffixes.keys())),
|
||||||
|
re.IGNORECASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
m = size_re.match(size_str.lower())
|
m = size_re.match(size_str.lower())
|
||||||
|
@ -297,12 +297,12 @@ if __name__ == "__main__":
|
|||||||
print("-------------------------------------------")
|
print("-------------------------------------------")
|
||||||
print(" Info ")
|
print(" Info ")
|
||||||
print("-------------------------------------------")
|
print("-------------------------------------------")
|
||||||
print("")
|
print()
|
||||||
print(f"* PyTorch version: {torch.__version__}")
|
print(f"* PyTorch version: {torch.__version__}")
|
||||||
print(f"* CUDA version: {torch.version.cuda}")
|
print(f"* CUDA version: {torch.version.cuda}")
|
||||||
print("")
|
print()
|
||||||
print("------------ nvidia-smi topo -m -----------")
|
print("------------ nvidia-smi topo -m -----------")
|
||||||
print("")
|
print()
|
||||||
print(output[0])
|
print(output[0])
|
||||||
print("-------------------------------------------")
|
print("-------------------------------------------")
|
||||||
print("PyTorch Distributed Benchmark (DDP and RPC)")
|
print("PyTorch Distributed Benchmark (DDP and RPC)")
|
||||||
|
@ -5018,5 +5018,5 @@ def munge_exc(e, *, suppress_suffix=True, suppress_prefix=True, file=None, skip=
|
|||||||
s = re.sub(r"\n*You can suppress this exception.+", "", s, flags=re.DOTALL)
|
s = re.sub(r"\n*You can suppress this exception.+", "", s, flags=re.DOTALL)
|
||||||
if suppress_prefix:
|
if suppress_prefix:
|
||||||
s = re.sub(r"Cannot export model.+\n\n", "", s)
|
s = re.sub(r"Cannot export model.+\n\n", "", s)
|
||||||
s = re.sub(r" +$", "", s, flags=re.M)
|
s = re.sub(r" +$", "", s, flags=re.MULTILINE)
|
||||||
return s
|
return s
|
||||||
|
Reference in New Issue
Block a user