[BE]: Enable ruff SLOT checks (#146276)

This enables a check that a class which only inherits from immutable classes such as str, tuple, and NamedTuple also defines `__slots__`, so its instances don't allocate memory unnecessarily. It also ensures contributors think about how they define classes that subclass NamedTuple and str, of which we have many in our codebase.
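
As a minimal illustration (the class names below are invented for this example, not taken from the diff): a subclass of an immutable built-in such as `str` that omits `__slots__` silently gains a per-instance `__dict__`, which is exactly the unnecessary allocation the SLOT rules flag.

```python
class PlainId(str):
    """str subclass without __slots__: each instance can grow a __dict__."""


class SlottedId(str):
    """str subclass with an empty __slots__: no per-instance __dict__ at all."""

    __slots__ = ()


plain = PlainId("deadbeef")
slotted = SlottedId("deadbeef")

print(hasattr(plain, "__dict__"))    # True  -> extra per-instance storage
print(hasattr(slotted, "__dict__"))  # False -> just the string data

plain.note = "allowed"               # silently stored in the per-instance __dict__
try:
    slotted.note = "allowed"         # rejected: no __dict__ and no matching slot
except AttributeError as err:
    print(err)
```

Where a subclass genuinely needs per-instance attributes, the changes below either name them in `__slots__` (as `_UnionTag` does with `("_cls",)`) or suppress the rule with a targeted `# noqa: SLOT...` comment.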

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146276
Approved by: https://github.com/aorenste
commit 7f65a20884
parent 3525b834f0
Author: Aaron Gokaslan
Date: 2025-02-04 19:18:21 +00:00
Committed by: PyTorch MergeBot

16 changed files with 32 additions and 7 deletions


@@ -30,6 +30,8 @@ UNZIPPED_CACHES = "unzipped-caches"
 # Since the pr identifier can be based on include user defined text (like a branch name)
 # we hash it to sanitize the input and avoid corner cases
 class PRIdentifier(str):
+    __slots__ = ()
+
     def __new__(cls, value: str) -> "PRIdentifier":
         md5 = hashlib.md5(value.encode("utf-8")).hexdigest()
         return super().__new__(cls, md5)


@@ -139,6 +139,8 @@ def seq(a, b):
 class isin:
+    __slots__ = ()
+
     def __contains__(self, item):
         for x in self:
             if seq(item, x):
@@ -153,11 +155,11 @@ class isin:
 class llist(isin, list):
-    pass
+    __slots__ = ()

 class ltuple(isin, tuple):
-    pass
+    __slots__ = ()

 empty_dict = {}


@@ -148,6 +148,7 @@ select = [
     "RUF019",  # unnecessary-key-check
     "RUF024",  # from keys mutable
     "RUF026",  # default factory kwarg
+    "SLOT",
     "TCH",
     "TRY002",  # ban vanilla raise (todo fix NOQAs)
     "TRY203",


@@ -4448,7 +4448,7 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(fn(inputs, x), opt_fn(inputs, x))

     def test_udf_tuple(self):
-        class MyTuple(tuple):
+        class MyTuple(tuple):  # noqa: SLOT001
             def len_mulitply_2(self):
                 return len(self) * 2
@@ -4475,7 +4475,7 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
         self.assertTrue(res_tup.checked)

     def test_udf_tuple_reconstruction(self):
-        class MyTuple(tuple):
+        class MyTuple(tuple):  # noqa: SLOT001
             pass

         def fn(x, klass):


@@ -413,7 +413,7 @@ class TestIndexing(TestCase):
         # A tuple subclass should also be an nd-index
         class TupleSubclass(tuple):
-            pass
+            __slots__ = ()

         index = ([1], [1])
         index = TupleSubclass(index)


@@ -105,6 +105,8 @@ if python_pytree._cxx_pytree_dynamo_traceable:
     __all__ += ["tree_leaves"]

     class _Asterisk(str):
+        __slots__ = ()
+
         def __new__(cls) -> Self:
             return super().__new__(cls, "*")


@@ -5,6 +5,7 @@ from dataclasses import fields

 class _UnionTag(str):
+    __slots__ = ("_cls",)
     _cls: Hashable

     @staticmethod


@@ -696,6 +696,8 @@ def render_call(fn, args, kwargs):

 class KeyErrorMessage(str):
     r"""str subclass that returns itself in repr"""

+    __slots__ = ()
+
     def __repr__(self):
         return self


@@ -266,6 +266,8 @@ class EqualizationQConfig(
         weight=_WeightEqualizationObserver.with_args(dtype=torch.qint8))
     """

+    __slots__ = ()
+
     def __new__(cls, input_activation=torch.nn.Identity, weight=torch.nn.Identity):
         if isinstance(input_activation, nn.Module) or isinstance(weight, nn.Module):
             raise ValueError(


@@ -102,6 +102,8 @@ class QConfig(namedtuple("QConfig", ["activation", "weight"])):
     """

+    __slots__ = ()
+
     def __new__(cls, activation, weight):
         # catch common mistakes
         if isinstance(activation, nn.Module) or isinstance(weight, nn.Module):
@@ -133,6 +135,8 @@ class QConfigDynamic(namedtuple("QConfigDynamic", ["activation", "weight"])):
         my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8))
     """

+    __slots__ = ()
+
     def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity):
         # catch common mistakes
         if isinstance(weight, nn.Module):


@@ -231,7 +231,8 @@ def supports_complex(reduceOp: ReduceOp) -> bool:
     return reduceOp not in denyList

-class Backend(str):
+# TODO refactor into enum/strenum
+class Backend(str):  # noqa: SLOT000
     """
     An enum-like class for backends.


@@ -18,6 +18,8 @@ class PassResult(namedtuple("PassResult", ["graph_module", "modified"])):
         modified: A flag for if the pass has modified the graph module
     """

+    __slots__ = ()
+
     def __new__(cls, graph_module, modified):
         return super().__new__(cls, graph_module, modified)


@@ -40,6 +40,8 @@ T = TypeVar("T", bound="Module")
 class _IncompatibleKeys(
     namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
 ):
+    __slots__ = ()
+
     def __repr__(self):
         if not self.missing_keys and not self.unexpected_keys:
             return "<All keys matched successfully>"


@@ -17,6 +17,8 @@ def _validate_dtypes(*dtypes):

 # class for tuples corresponding to a PyTorch dispatch macro
 class _dispatch_dtypes(tuple):
+    __slots__ = ()
+
     def __add__(self, other):
         assert isinstance(other, tuple)
         return _dispatch_dtypes(tuple.__add__(self, other))


@@ -33,7 +33,7 @@ def unpack_variables(args):
         return args

 class dont_convert(tuple):
-    pass
+    __slots__ = ()

 non_differentiable = collections.namedtuple('non_differentiable', ['tensor'])


@@ -26,6 +26,8 @@ class TorchVersion(str):
         TorchVersion('1.10.0a') > '1.2.1'
     """

+    __slots__ = ()
+
     # fully qualified type names here to appease mypy
     def _convert_to_version(self, inp: Any) -> Any:
         if isinstance(inp, Version):