[Ez][BE]: Enable new stable ruff rules (#129825)

Applies a batch of ruff lint rules that have now graduated from preview to stable. Some of them improve efficiency or readability. Since I already ran passes over the codebase for these rules while they were in preview, relatively few changes are needed here; this is mostly future hardening of the codebase.
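
Most of the changes below are mechanical autofixes. A minimal sketch of the recurring patterns, as I read them out of the diff (the rule codes are my attribution from the ruff/refurb docs, and the variable names are invented, not taken from the PyTorch tree):

    import re

    s, x, limit = "foo bar", 12, 10

    # Before: forms the newly stabilized rules flag.
    print("")                       # FURB105: print of an empty string
    re.search("foo", s, re.I)       # FURB167: terse regex flag alias
    x = x if x <= limit else limit  # FURB136: hand-rolled min()

    # After: the fixed forms that recur throughout this commit.
    print()
    re.search("foo", s, re.IGNORECASE)
    x = min(x, limit)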

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129825
Approved by: https://github.com/XuehaiPan, https://github.com/jansel, https://github.com/malfet
Author: Aaron Gokaslan
Date: 2024-07-02 14:47:08 +00:00
Committed by: PyTorch MergeBot
Parent commit: 2926655761
Commit: 6c2a8b6b38

16 changed files with 45 additions and 38 deletions


@@ -226,17 +226,17 @@ def main():
     print("-----------------------------------")
     print("PyTorch distributed benchmark suite")
     print("-----------------------------------")
-    print("")
+    print()
     print(f"* PyTorch version: {torch.__version__}")
     print(f"* CUDA version: {torch.version.cuda}")
     print(f"* Distributed backend: {args.distributed_backend}")
     print(f"* Maximum bucket size: {args.bucket_size}MB")
-    print("")
+    print()
     print("--- nvidia-smi topo -m ---")
-    print("")
+    print()
     print(output[0])
     print("--------------------------")
-    print("")
+    print()
 
     torch.cuda.set_device(dist.get_rank() % 8)
     device = torch.device("cuda:%d" % (dist.get_rank() % 8))


@@ -32,7 +32,7 @@ def main():
         va = str(ja.get(key, "-"))
         vb = str(jb.get(key, "-"))
         print(f"{key + ':':20s} {va:>20s} vs {vb:>20s}")
-    print("")
+    print()
 
     ba = ja["benchmark_results"]
     bb = jb["benchmark_results"]
@@ -48,13 +48,13 @@ def main():
         print(f"Benchmark: {name}")
 
         # Print header
-        print("")
+        print()
         print(f"{'':>10s}", end="")  # noqa: E999
         for _ in [75, 95]:
             print(
                 f"{'sec/iter':>16s}{'ex/sec':>10s}{'diff':>10s}", end=""
             )  # noqa: E999
-        print("")
+        print()
 
         # Print measurements
         for i, (xa, xb) in enumerate(zip(ra["result"], rb["result"])):
@@ -78,8 +78,8 @@ def main():
                     f" p{p:02d}: {vb:8.3f}s {int(batch_size / vb):7d}/s {delta:+8.1f}%",
                     end="",
                 )  # noqa: E999
-            print("")
-        print("")
+            print()
+        print()
 
 
 if __name__ == "__main__":


@@ -318,8 +318,8 @@ class TimmRunner(BenchmarkRunner):
             if index < start or index >= end:
                 continue
             if (
-                not re.search("|".join(args.filter), model_name, re.I)
-                or re.search("|".join(args.exclude), model_name, re.I)
+                not re.search("|".join(args.filter), model_name, re.IGNORECASE)
+                or re.search("|".join(args.exclude), model_name, re.IGNORECASE)
                 or model_name in args.exclude_exact
                 or model_name in self.skip_models
             ):


@@ -399,8 +399,8 @@ class TorchBenchmarkRunner(BenchmarkRunner):
         model_name = os.path.basename(model_path)
         if (
-            not re.search("|".join(args.filter), model_name, re.I)
-            or re.search("|".join(args.exclude), model_name, re.I)
+            not re.search("|".join(args.filter), model_name, re.IGNORECASE)
+            or re.search("|".join(args.exclude), model_name, re.IGNORECASE)
             or model_name in args.exclude_exact
             or model_name in self.skip_models
         ):


@@ -79,7 +79,7 @@ def test_rnns(
     if verbose:
         print(experim.forward.graph_for(*experim.inputs))
-    print("")
+    print()
 
 
 def test_vl_py(**test_args):
@@ -141,7 +141,7 @@ def test_vl_py(**test_args):
     if test_args["verbose"]:
         print(experim.forward.graph_for(*experim.inputs))
-    print("")
+    print()
 
 
 if __name__ == "__main__":


@@ -97,6 +97,7 @@ select = [
     "SIM1",
     "W",
     # Not included in flake8
+    "FURB",
     "LOG",
     "NPY",
     "PERF",
@@ -113,10 +114,13 @@ select = [
     "PLR0133", # constant comparison
     "PLR0206", # property with params
     "PLR1722", # use sys exit
+    "PLR1736", # unnecessary list index
     "PLW0129", # assert on string literal
+    "PLW0133", # useless exception statement
     "PLW0406", # import self
     "PLW0711", # binary op exception
     "PLW1509", # preexec_fn not safe with threads
+    "PLW2101", # useless lock statement
     "PLW3301", # nested min max
     "PT006", # TODO: enable more PT rules
     "PT022",
@@ -133,6 +137,8 @@ select = [
     "RUF016", # type error non-integer index
     "RUF017",
     "RUF018", # no assignment in assert
+    "RUF024", # from keys mutable
+    "RUF026", # default factory kwarg
     "TCH",
     "TRY002", # ban vanilla raise (todo fix NOQAs)
     "TRY302",


@@ -10,7 +10,7 @@ def check_error(desc, fn, *required_substrings):
         print(desc)
         print("-" * 80)
         print(error_message)
-        print("")
+        print()
         for sub in required_substrings:
             assert sub in error_message
         return


@@ -51,7 +51,7 @@ class UpscaleBlock(nn.Module):
 
 class SRResNet(nn.Module):
     def __init__(self, rescale_factor, n_filters, n_blocks):
         super().__init__()
-        self.rescale_levels = int(math.log(rescale_factor, 2))
+        self.rescale_levels = int(math.log(rescale_factor, 2))  # noqa: FURB163
         self.n_filters = n_filters
         self.n_blocks = n_blocks
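
FURB163 flags calls like math.log(x, 2) in favor of math.log2(x); the noqa above opts this line out of the autofix rather than changing the code. For reference, a sketch of what the rule would suggest (illustrative, not from the commit):

    import math

    rescale_factor = 4
    # math.log2 is the form FURB163 prefers; for exact powers of two it is
    # also more accurate than math.log(x, 2).
    assert int(math.log2(rescale_factor)) == int(math.log(rescale_factor, 2))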


@@ -80,7 +80,7 @@ def evaluate(model, criterion, data_loader):
             acc1, acc5 = accuracy(output, target, topk=(1, 5))
             top1.update(acc1[0], image.size(0))
             top5.update(acc5[0], image.size(0))
 
-    print('')
+    print()
     return top1, top5


@@ -1903,8 +1903,8 @@ class TestQuantizedOps(TestCase):
             X = np.array(X)
             scale = 1
             H, W = X.shape[-2:]
-            output_size_h = output_size_h if (output_size_h <= H) else H
-            output_size_w = output_size_w if (output_size_w <= W) else W
+            output_size_h = min(output_size_h, H)
+            output_size_w = min(output_size_w, W)
             if output_size_h == output_size_w:
                 output_size = output_size_h
             else:
@@ -1977,9 +1977,9 @@ class TestQuantizedOps(TestCase):
                 dim_to_check.append(3)
             D, H, W = X.shape[-3:]
-            output_size_d = output_size_d if (output_size_d <= D) else D
-            output_size_h = output_size_h if (output_size_h <= H) else H
-            output_size_w = output_size_w if (output_size_w <= W) else W
+            output_size_d = min(output_size_d, D)
+            output_size_h = min(output_size_h, H)
+            output_size_w = min(output_size_w, W)
 
             X = torch.from_numpy(X)
             qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
@@ -2049,9 +2049,9 @@ class TestQuantizedOps(TestCase):
             X = np.array(X)
             scale = 1
             D, H, W = X.shape[-3:]
-            output_size_d = output_size_d if (output_size_d <= D) else D
-            output_size_h = output_size_h if (output_size_h <= H) else H
-            output_size_w = output_size_w if (output_size_w <= W) else W
+            output_size_d = min(output_size_d, D)
+            output_size_h = min(output_size_h, H)
+            output_size_w = min(output_size_w, W)
             if output_size_d == output_size_h == output_size_w:
                 output_size = output_size_h
             else:


@@ -4661,7 +4661,7 @@ Done""",
         self.assertEqual(avg.device_time_total, 0)
 
     def test_profiler_shapes(self):
-        print("")
+        print()
         layer1 = torch.nn.Linear(20, 30)
         layer2 = torch.nn.Linear(30, 40)
         input = torch.randn(128, 20)
@@ -4683,7 +4683,7 @@ Done""",
         self.assertEqual(len(found_indices), len(linear_expected_shapes))
 
     def test_profiler_aggregation_lstm(self):
-        print("")
+        print()
         rnn = torch.nn.LSTM(10, 20, 2)
         total_time_s = 0
         with profile(record_shapes=True, use_kineto=kineto_available()) as prof:


@@ -1632,7 +1632,7 @@ def choose_saved_values_set(
         for i, txt in enumerate(x_values):
             plt.annotate(
                 f"{txt:.2f}",
-                (x_values[i], y_values[i]),
+                (txt, y_values[i]),
                 textcoords="offset points",
                 xytext=(0, 10),
                 ha="center",


@@ -430,8 +430,8 @@ class ExprPrinter(Printer):
         if (
             isinstance(string, CSEVariable)
-            or re.match(r"^[a-z0-9_.]+$", string, re.I)
-            or re.match(r"^\([^)]*\)$", string, re.I)
+            or re.match(r"^[a-z0-9_.]+$", string, re.IGNORECASE)
+            or re.match(r"^\([^)]*\)$", string, re.IGNORECASE)
             or string == ""
         ):
             return string


@@ -1966,11 +1966,11 @@ class suppress_warnings:
                 self._clear_registries()
             self._tmp_suppressions.append(
-                (category, message, re.compile(message, re.I), module, record)
+                (category, message, re.compile(message, re.IGNORECASE), module, record)
             )
         else:
             self._suppressions.append(
-                (category, message, re.compile(message, re.I), module, record)
+                (category, message, re.compile(message, re.IGNORECASE), module, record)
             )
         return record
@@ -2318,7 +2318,8 @@ def _parse_size(size_str):
     }
     size_re = re.compile(
-        r"^\s*(\d+|\d+\.\d+)\s*({})\s*$".format("|".join(suffixes.keys())), re.I
+        r"^\s*(\d+|\d+\.\d+)\s*({})\s*$".format("|".join(suffixes.keys())),
+        re.IGNORECASE,
     )
     m = size_re.match(size_str.lower())


@@ -297,12 +297,12 @@ if __name__ == "__main__":
     print("-------------------------------------------")
     print(" Info ")
     print("-------------------------------------------")
-    print("")
+    print()
     print(f"* PyTorch version: {torch.__version__}")
     print(f"* CUDA version: {torch.version.cuda}")
-    print("")
+    print()
     print("------------ nvidia-smi topo -m -----------")
-    print("")
+    print()
     print(output[0])
     print("-------------------------------------------")
     print("PyTorch Distributed Benchmark (DDP and RPC)")


@@ -5018,5 +5018,5 @@ def munge_exc(e, *, suppress_suffix=True, suppress_prefix=True, file=None, skip=
         s = re.sub(r"\n*You can suppress this exception.+", "", s, flags=re.DOTALL)
     if suppress_prefix:
         s = re.sub(r"Cannot export model.+\n\n", "", s)
-    s = re.sub(r" +$", "", s, flags=re.M)
+    s = re.sub(r" +$", "", s, flags=re.MULTILINE)
     return s