mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[BE]: Enable ruff SLOT checks (#146276)
This enables a check that which a class which only inherits from immutable classes like str, tuple, and NamedTuple, also defined `__slots__` so they don't allocate memory unnecessarily. This also ensure contributors think about how they define their classes with subclass NamedTuples and str, of which we have many in our codebase Pull Request resolved: https://github.com/pytorch/pytorch/pull/146276 Approved by: https://github.com/aorenste
This commit is contained in:
committed by
PyTorch MergeBot
parent
3525b834f0
commit
7f65a20884
2
.github/scripts/pytest_caching_utils.py
vendored
2
.github/scripts/pytest_caching_utils.py
vendored
@ -30,6 +30,8 @@ UNZIPPED_CACHES = "unzipped-caches"
|
||||
# Since the pr identifier can be based on include user defined text (like a branch name)
|
||||
# we hash it to sanitize the input and avoid corner cases
|
||||
class PRIdentifier(str):
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls, value: str) -> "PRIdentifier":
|
||||
md5 = hashlib.md5(value.encode("utf-8")).hexdigest()
|
||||
return super().__new__(cls, md5)
|
||||
|
@ -139,6 +139,8 @@ def seq(a, b):
|
||||
|
||||
|
||||
class isin:
|
||||
__slots__ = ()
|
||||
|
||||
def __contains__(self, item):
|
||||
for x in self:
|
||||
if seq(item, x):
|
||||
@ -153,11 +155,11 @@ class isin:
|
||||
|
||||
|
||||
class llist(isin, list):
|
||||
pass
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
class ltuple(isin, tuple):
|
||||
pass
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
empty_dict = {}
|
||||
|
@ -148,6 +148,7 @@ select = [
|
||||
"RUF019", # unnecessary-key-check
|
||||
"RUF024", # from keys mutable
|
||||
"RUF026", # default factory kwarg
|
||||
"SLOT",
|
||||
"TCH",
|
||||
"TRY002", # ban vanilla raise (todo fix NOQAs)
|
||||
"TRY203",
|
||||
|
@ -4448,7 +4448,7 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertEqual(fn(inputs, x), opt_fn(inputs, x))
|
||||
|
||||
def test_udf_tuple(self):
|
||||
class MyTuple(tuple):
|
||||
class MyTuple(tuple): # noqa: SLOT001
|
||||
def len_mulitply_2(self):
|
||||
return len(self) * 2
|
||||
|
||||
@ -4475,7 +4475,7 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertTrue(res_tup.checked)
|
||||
|
||||
def test_udf_tuple_reconstruction(self):
|
||||
class MyTuple(tuple):
|
||||
class MyTuple(tuple): # noqa: SLOT001
|
||||
pass
|
||||
|
||||
def fn(x, klass):
|
||||
|
@ -413,7 +413,7 @@ class TestIndexing(TestCase):
|
||||
|
||||
# A tuple subclass should also be an nd-index
|
||||
class TupleSubclass(tuple):
|
||||
pass
|
||||
__slots__ = ()
|
||||
|
||||
index = ([1], [1])
|
||||
index = TupleSubclass(index)
|
||||
|
@ -105,6 +105,8 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||
__all__ += ["tree_leaves"]
|
||||
|
||||
class _Asterisk(str):
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls) -> Self:
|
||||
return super().__new__(cls, "*")
|
||||
|
||||
|
@ -5,6 +5,7 @@ from dataclasses import fields
|
||||
|
||||
|
||||
class _UnionTag(str):
|
||||
__slots__ = ("_cls",)
|
||||
_cls: Hashable
|
||||
|
||||
@staticmethod
|
||||
|
@ -696,6 +696,8 @@ def render_call(fn, args, kwargs):
|
||||
class KeyErrorMessage(str):
|
||||
r"""str subclass that returns itself in repr"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __repr__(self):
|
||||
return self
|
||||
|
||||
|
@ -266,6 +266,8 @@ class EqualizationQConfig(
|
||||
weight=_WeightEqualizationObserver.with_args(dtype=torch.qint8))
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls, input_activation=torch.nn.Identity, weight=torch.nn.Identity):
|
||||
if isinstance(input_activation, nn.Module) or isinstance(weight, nn.Module):
|
||||
raise ValueError(
|
||||
|
@ -102,6 +102,8 @@ class QConfig(namedtuple("QConfig", ["activation", "weight"])):
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls, activation, weight):
|
||||
# catch common mistakes
|
||||
if isinstance(activation, nn.Module) or isinstance(weight, nn.Module):
|
||||
@ -133,6 +135,8 @@ class QConfigDynamic(namedtuple("QConfigDynamic", ["activation", "weight"])):
|
||||
my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8))
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity):
|
||||
# catch common mistakes
|
||||
if isinstance(weight, nn.Module):
|
||||
|
@ -231,7 +231,8 @@ def supports_complex(reduceOp: ReduceOp) -> bool:
|
||||
return reduceOp not in denyList
|
||||
|
||||
|
||||
class Backend(str):
|
||||
# TODO refactor into enum/strenum
|
||||
class Backend(str): # noqa: SLOT000
|
||||
"""
|
||||
An enum-like class for backends.
|
||||
|
||||
|
@ -18,6 +18,8 @@ class PassResult(namedtuple("PassResult", ["graph_module", "modified"])):
|
||||
modified: A flag for if the pass has modified the graph module
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls, graph_module, modified):
|
||||
return super().__new__(cls, graph_module, modified)
|
||||
|
||||
|
@ -40,6 +40,8 @@ T = TypeVar("T", bound="Module")
|
||||
class _IncompatibleKeys(
|
||||
namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
|
||||
):
|
||||
__slots__ = ()
|
||||
|
||||
def __repr__(self):
|
||||
if not self.missing_keys and not self.unexpected_keys:
|
||||
return "<All keys matched successfully>"
|
||||
|
@ -17,6 +17,8 @@ def _validate_dtypes(*dtypes):
|
||||
|
||||
# class for tuples corresponding to a PyTorch dispatch macro
|
||||
class _dispatch_dtypes(tuple):
|
||||
__slots__ = ()
|
||||
|
||||
def __add__(self, other):
|
||||
assert isinstance(other, tuple)
|
||||
return _dispatch_dtypes(tuple.__add__(self, other))
|
||||
|
@ -33,7 +33,7 @@ def unpack_variables(args):
|
||||
return args
|
||||
|
||||
class dont_convert(tuple):
|
||||
pass
|
||||
__slots__ = ()
|
||||
|
||||
non_differentiable = collections.namedtuple('non_differentiable', ['tensor'])
|
||||
|
||||
|
@ -26,6 +26,8 @@ class TorchVersion(str):
|
||||
TorchVersion('1.10.0a') > '1.2.1'
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
# fully qualified type names here to appease mypy
|
||||
def _convert_to_version(self, inp: Any) -> Any:
|
||||
if isinstance(inp, Version):
|
||||
|
Reference in New Issue
Block a user