mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE]: Enable ruff SLOT checks (#146276)
This enables a check that which a class which only inherits from immutable classes like str, tuple, and NamedTuple, also defined `__slots__` so they don't allocate memory unnecessarily. This also ensure contributors think about how they define their classes with subclass NamedTuples and str, of which we have many in our codebase Pull Request resolved: https://github.com/pytorch/pytorch/pull/146276 Approved by: https://github.com/aorenste
This commit is contained in:
committed by
PyTorch MergeBot
parent
3525b834f0
commit
7f65a20884
2
.github/scripts/pytest_caching_utils.py
vendored
2
.github/scripts/pytest_caching_utils.py
vendored
@ -30,6 +30,8 @@ UNZIPPED_CACHES = "unzipped-caches"
|
|||||||
# Since the pr identifier can be based on include user defined text (like a branch name)
|
# Since the pr identifier can be based on include user defined text (like a branch name)
|
||||||
# we hash it to sanitize the input and avoid corner cases
|
# we hash it to sanitize the input and avoid corner cases
|
||||||
class PRIdentifier(str):
|
class PRIdentifier(str):
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
def __new__(cls, value: str) -> "PRIdentifier":
|
def __new__(cls, value: str) -> "PRIdentifier":
|
||||||
md5 = hashlib.md5(value.encode("utf-8")).hexdigest()
|
md5 = hashlib.md5(value.encode("utf-8")).hexdigest()
|
||||||
return super().__new__(cls, md5)
|
return super().__new__(cls, md5)
|
||||||
|
@ -139,6 +139,8 @@ def seq(a, b):
|
|||||||
|
|
||||||
|
|
||||||
class isin:
|
class isin:
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
def __contains__(self, item):
|
def __contains__(self, item):
|
||||||
for x in self:
|
for x in self:
|
||||||
if seq(item, x):
|
if seq(item, x):
|
||||||
@ -153,11 +155,11 @@ class isin:
|
|||||||
|
|
||||||
|
|
||||||
class llist(isin, list):
|
class llist(isin, list):
|
||||||
pass
|
__slots__ = ()
|
||||||
|
|
||||||
|
|
||||||
class ltuple(isin, tuple):
|
class ltuple(isin, tuple):
|
||||||
pass
|
__slots__ = ()
|
||||||
|
|
||||||
|
|
||||||
empty_dict = {}
|
empty_dict = {}
|
||||||
|
@ -148,6 +148,7 @@ select = [
|
|||||||
"RUF019", # unnecessary-key-check
|
"RUF019", # unnecessary-key-check
|
||||||
"RUF024", # from keys mutable
|
"RUF024", # from keys mutable
|
||||||
"RUF026", # default factory kwarg
|
"RUF026", # default factory kwarg
|
||||||
|
"SLOT",
|
||||||
"TCH",
|
"TCH",
|
||||||
"TRY002", # ban vanilla raise (todo fix NOQAs)
|
"TRY002", # ban vanilla raise (todo fix NOQAs)
|
||||||
"TRY203",
|
"TRY203",
|
||||||
|
@ -4448,7 +4448,7 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
|
|||||||
self.assertEqual(fn(inputs, x), opt_fn(inputs, x))
|
self.assertEqual(fn(inputs, x), opt_fn(inputs, x))
|
||||||
|
|
||||||
def test_udf_tuple(self):
|
def test_udf_tuple(self):
|
||||||
class MyTuple(tuple):
|
class MyTuple(tuple): # noqa: SLOT001
|
||||||
def len_mulitply_2(self):
|
def len_mulitply_2(self):
|
||||||
return len(self) * 2
|
return len(self) * 2
|
||||||
|
|
||||||
@ -4475,7 +4475,7 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
|
|||||||
self.assertTrue(res_tup.checked)
|
self.assertTrue(res_tup.checked)
|
||||||
|
|
||||||
def test_udf_tuple_reconstruction(self):
|
def test_udf_tuple_reconstruction(self):
|
||||||
class MyTuple(tuple):
|
class MyTuple(tuple): # noqa: SLOT001
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def fn(x, klass):
|
def fn(x, klass):
|
||||||
|
@ -413,7 +413,7 @@ class TestIndexing(TestCase):
|
|||||||
|
|
||||||
# A tuple subclass should also be an nd-index
|
# A tuple subclass should also be an nd-index
|
||||||
class TupleSubclass(tuple):
|
class TupleSubclass(tuple):
|
||||||
pass
|
__slots__ = ()
|
||||||
|
|
||||||
index = ([1], [1])
|
index = ([1], [1])
|
||||||
index = TupleSubclass(index)
|
index = TupleSubclass(index)
|
||||||
|
@ -105,6 +105,8 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
|||||||
__all__ += ["tree_leaves"]
|
__all__ += ["tree_leaves"]
|
||||||
|
|
||||||
class _Asterisk(str):
|
class _Asterisk(str):
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
def __new__(cls) -> Self:
|
def __new__(cls) -> Self:
|
||||||
return super().__new__(cls, "*")
|
return super().__new__(cls, "*")
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@ from dataclasses import fields
|
|||||||
|
|
||||||
|
|
||||||
class _UnionTag(str):
|
class _UnionTag(str):
|
||||||
|
__slots__ = ("_cls",)
|
||||||
_cls: Hashable
|
_cls: Hashable
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -696,6 +696,8 @@ def render_call(fn, args, kwargs):
|
|||||||
class KeyErrorMessage(str):
|
class KeyErrorMessage(str):
|
||||||
r"""str subclass that returns itself in repr"""
|
r"""str subclass that returns itself in repr"""
|
||||||
|
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@ -266,6 +266,8 @@ class EqualizationQConfig(
|
|||||||
weight=_WeightEqualizationObserver.with_args(dtype=torch.qint8))
|
weight=_WeightEqualizationObserver.with_args(dtype=torch.qint8))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
def __new__(cls, input_activation=torch.nn.Identity, weight=torch.nn.Identity):
|
def __new__(cls, input_activation=torch.nn.Identity, weight=torch.nn.Identity):
|
||||||
if isinstance(input_activation, nn.Module) or isinstance(weight, nn.Module):
|
if isinstance(input_activation, nn.Module) or isinstance(weight, nn.Module):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -102,6 +102,8 @@ class QConfig(namedtuple("QConfig", ["activation", "weight"])):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
def __new__(cls, activation, weight):
|
def __new__(cls, activation, weight):
|
||||||
# catch common mistakes
|
# catch common mistakes
|
||||||
if isinstance(activation, nn.Module) or isinstance(weight, nn.Module):
|
if isinstance(activation, nn.Module) or isinstance(weight, nn.Module):
|
||||||
@ -133,6 +135,8 @@ class QConfigDynamic(namedtuple("QConfigDynamic", ["activation", "weight"])):
|
|||||||
my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8))
|
my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity):
|
def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity):
|
||||||
# catch common mistakes
|
# catch common mistakes
|
||||||
if isinstance(weight, nn.Module):
|
if isinstance(weight, nn.Module):
|
||||||
|
@ -231,7 +231,8 @@ def supports_complex(reduceOp: ReduceOp) -> bool:
|
|||||||
return reduceOp not in denyList
|
return reduceOp not in denyList
|
||||||
|
|
||||||
|
|
||||||
class Backend(str):
|
# TODO refactor into enum/strenum
|
||||||
|
class Backend(str): # noqa: SLOT000
|
||||||
"""
|
"""
|
||||||
An enum-like class for backends.
|
An enum-like class for backends.
|
||||||
|
|
||||||
|
@ -18,6 +18,8 @@ class PassResult(namedtuple("PassResult", ["graph_module", "modified"])):
|
|||||||
modified: A flag for if the pass has modified the graph module
|
modified: A flag for if the pass has modified the graph module
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
def __new__(cls, graph_module, modified):
|
def __new__(cls, graph_module, modified):
|
||||||
return super().__new__(cls, graph_module, modified)
|
return super().__new__(cls, graph_module, modified)
|
||||||
|
|
||||||
|
@ -40,6 +40,8 @@ T = TypeVar("T", bound="Module")
|
|||||||
class _IncompatibleKeys(
|
class _IncompatibleKeys(
|
||||||
namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
|
namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
|
||||||
):
|
):
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
if not self.missing_keys and not self.unexpected_keys:
|
if not self.missing_keys and not self.unexpected_keys:
|
||||||
return "<All keys matched successfully>"
|
return "<All keys matched successfully>"
|
||||||
|
@ -17,6 +17,8 @@ def _validate_dtypes(*dtypes):
|
|||||||
|
|
||||||
# class for tuples corresponding to a PyTorch dispatch macro
|
# class for tuples corresponding to a PyTorch dispatch macro
|
||||||
class _dispatch_dtypes(tuple):
|
class _dispatch_dtypes(tuple):
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
def __add__(self, other):
|
def __add__(self, other):
|
||||||
assert isinstance(other, tuple)
|
assert isinstance(other, tuple)
|
||||||
return _dispatch_dtypes(tuple.__add__(self, other))
|
return _dispatch_dtypes(tuple.__add__(self, other))
|
||||||
|
@ -33,7 +33,7 @@ def unpack_variables(args):
|
|||||||
return args
|
return args
|
||||||
|
|
||||||
class dont_convert(tuple):
|
class dont_convert(tuple):
|
||||||
pass
|
__slots__ = ()
|
||||||
|
|
||||||
non_differentiable = collections.namedtuple('non_differentiable', ['tensor'])
|
non_differentiable = collections.namedtuple('non_differentiable', ['tensor'])
|
||||||
|
|
||||||
|
@ -26,6 +26,8 @@ class TorchVersion(str):
|
|||||||
TorchVersion('1.10.0a') > '1.2.1'
|
TorchVersion('1.10.0a') > '1.2.1'
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
# fully qualified type names here to appease mypy
|
# fully qualified type names here to appease mypy
|
||||||
def _convert_to_version(self, inp: Any) -> Any:
|
def _convert_to_version(self, inp: Any) -> Any:
|
||||||
if isinstance(inp, Version):
|
if isinstance(inp, Version):
|
||||||
|
Reference in New Issue
Block a user