[BE]: Enable ruff SLOT checks (#146276)

This enables a check that which a class which only inherits from immutable classes like str, tuple, and NamedTuple, also defined `__slots__` so they don't allocate memory unnecessarily. This also ensure contributors think about how they define their classes with subclass NamedTuples and str, of which we have many in our codebase

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146276
Approved by: https://github.com/aorenste
This commit is contained in:
Aaron Gokaslan
2025-02-04 19:18:21 +00:00
committed by PyTorch MergeBot
parent 3525b834f0
commit 7f65a20884
16 changed files with 32 additions and 7 deletions

View File

@ -30,6 +30,8 @@ UNZIPPED_CACHES = "unzipped-caches"
# Since the pr identifier can be based on include user defined text (like a branch name)
# we hash it to sanitize the input and avoid corner cases
class PRIdentifier(str):
__slots__ = ()
def __new__(cls, value: str) -> "PRIdentifier":
md5 = hashlib.md5(value.encode("utf-8")).hexdigest()
return super().__new__(cls, md5)

View File

@ -139,6 +139,8 @@ def seq(a, b):
class isin:
__slots__ = ()
def __contains__(self, item):
for x in self:
if seq(item, x):
@ -153,11 +155,11 @@ class isin:
class llist(isin, list):
pass
__slots__ = ()
class ltuple(isin, tuple):
pass
__slots__ = ()
empty_dict = {}

View File

@ -148,6 +148,7 @@ select = [
"RUF019", # unnecessary-key-check
"RUF024", # from keys mutable
"RUF026", # default factory kwarg
"SLOT",
"TCH",
"TRY002", # ban vanilla raise (todo fix NOQAs)
"TRY203",

View File

@ -4448,7 +4448,7 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
self.assertEqual(fn(inputs, x), opt_fn(inputs, x))
def test_udf_tuple(self):
class MyTuple(tuple):
class MyTuple(tuple): # noqa: SLOT001
def len_mulitply_2(self):
return len(self) * 2
@ -4475,7 +4475,7 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
self.assertTrue(res_tup.checked)
def test_udf_tuple_reconstruction(self):
class MyTuple(tuple):
class MyTuple(tuple): # noqa: SLOT001
pass
def fn(x, klass):

View File

@ -413,7 +413,7 @@ class TestIndexing(TestCase):
# A tuple subclass should also be an nd-index
class TupleSubclass(tuple):
pass
__slots__ = ()
index = ([1], [1])
index = TupleSubclass(index)

View File

@ -105,6 +105,8 @@ if python_pytree._cxx_pytree_dynamo_traceable:
__all__ += ["tree_leaves"]
class _Asterisk(str):
__slots__ = ()
def __new__(cls) -> Self:
return super().__new__(cls, "*")

View File

@ -5,6 +5,7 @@ from dataclasses import fields
class _UnionTag(str):
__slots__ = ("_cls",)
_cls: Hashable
@staticmethod

View File

@ -696,6 +696,8 @@ def render_call(fn, args, kwargs):
class KeyErrorMessage(str):
r"""str subclass that returns itself in repr"""
__slots__ = ()
def __repr__(self):
return self

View File

@ -266,6 +266,8 @@ class EqualizationQConfig(
weight=_WeightEqualizationObserver.with_args(dtype=torch.qint8))
"""
__slots__ = ()
def __new__(cls, input_activation=torch.nn.Identity, weight=torch.nn.Identity):
if isinstance(input_activation, nn.Module) or isinstance(weight, nn.Module):
raise ValueError(

View File

@ -102,6 +102,8 @@ class QConfig(namedtuple("QConfig", ["activation", "weight"])):
"""
__slots__ = ()
def __new__(cls, activation, weight):
# catch common mistakes
if isinstance(activation, nn.Module) or isinstance(weight, nn.Module):
@ -133,6 +135,8 @@ class QConfigDynamic(namedtuple("QConfigDynamic", ["activation", "weight"])):
my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8))
"""
__slots__ = ()
def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity):
# catch common mistakes
if isinstance(weight, nn.Module):

View File

@ -231,7 +231,8 @@ def supports_complex(reduceOp: ReduceOp) -> bool:
return reduceOp not in denyList
class Backend(str):
# TODO refactor into enum/strenum
class Backend(str): # noqa: SLOT000
"""
An enum-like class for backends.

View File

@ -18,6 +18,8 @@ class PassResult(namedtuple("PassResult", ["graph_module", "modified"])):
modified: A flag for if the pass has modified the graph module
"""
__slots__ = ()
def __new__(cls, graph_module, modified):
return super().__new__(cls, graph_module, modified)

View File

@ -40,6 +40,8 @@ T = TypeVar("T", bound="Module")
class _IncompatibleKeys(
namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
):
__slots__ = ()
def __repr__(self):
if not self.missing_keys and not self.unexpected_keys:
return "<All keys matched successfully>"

View File

@ -17,6 +17,8 @@ def _validate_dtypes(*dtypes):
# class for tuples corresponding to a PyTorch dispatch macro
class _dispatch_dtypes(tuple):
__slots__ = ()
def __add__(self, other):
assert isinstance(other, tuple)
return _dispatch_dtypes(tuple.__add__(self, other))

View File

@ -33,7 +33,7 @@ def unpack_variables(args):
return args
class dont_convert(tuple):
pass
__slots__ = ()
non_differentiable = collections.namedtuple('non_differentiable', ['tensor'])

View File

@ -26,6 +26,8 @@ class TorchVersion(str):
TorchVersion('1.10.0a') > '1.2.1'
"""
__slots__ = ()
# fully qualified type names here to appease mypy
def _convert_to_version(self, inp: Any) -> Any:
if isinstance(inp, Version):