Revert "added persistent option to buffers and namedbuffers (#132994)"

This reverts commit 8707c6dfacaed293ddc40cbb5ecf5841568df0e6.

Reverted https://github.com/pytorch/pytorch/pull/132994 on behalf of https://github.com/PaliC due to breaking internal pyre tests ([comment](https://github.com/pytorch/pytorch/pull/132994#issuecomment-2278487672))
This commit is contained in:
PyTorch MergeBot
2024-08-09 18:14:51 +00:00
parent 6c012f7217
commit 31ef900a65
3 changed files with 12 additions and 66 deletions

View File

@ -369,24 +369,6 @@ class TestNN(NNTestCase):
self.assertEqual(names(m.named_buffers(remove_duplicate=False)),
["buffer1", "buffer2"])
# test persistent
class Foo(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer('buffer1', torch.empty(3, 5), persistent=True)
self.register_buffer('buffer2', torch.empty(3, 5), persistent=False)
foo = Foo()
# persistent = True
self.assertEqual(len(list(foo.buffers(persistent=True))), 1)
self.assertEqual(names(foo.named_buffers(persistent=True)), ["buffer1"])
# persistent = False
self.assertEqual(len(list(foo.buffers(persistent=False))), 1)
self.assertEqual(names(foo.named_buffers(persistent=False)), ["buffer2"])
# persistent = None
self.assertEqual(len(list(foo.buffers(persistent=None))), 2)
self.assertEqual(names(foo.named_buffers(persistent=None)), ["buffer1", "buffer2"])
def test_buffer_bad_module_subclass(self):
class MyBadModule(nn.Linear):
def __init__(self) -> None:

View File

@ -412,16 +412,11 @@ class _RemoteModule(nn.Module):
) -> Iterator[Tuple[str, Parameter]]:
_raise_not_supported(self.named_parameters.__name__)
def buffers(self, recurse: bool = True, *, persistent: Optional[bool] = None) -> Iterator[Tensor]: # type: ignore[return]
def buffers(self, recurse: bool = True) -> Iterator[Tensor]: # type: ignore[return]
_raise_not_supported(self.buffers.__name__)
def named_buffers( # type: ignore[return]
self,
prefix: str = "",
recurse: bool = True,
remove_duplicate: bool = True,
*,
persistent: Optional[bool] = None,
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, Tensor]]:
_raise_not_supported(self.named_buffers.__name__)

View File

@ -10,7 +10,6 @@ from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Mapping,
@ -2574,12 +2573,8 @@ class Module:
return _IncompatibleKeys(missing_keys, unexpected_keys)
def _named_members(
self,
get_members_fn: Callable[[Self], Iterable[Tuple[str, Any]]],
prefix="",
recurse: bool = True,
remove_duplicate: bool = True,
) -> Iterator[Tuple[str, Any]]:
self, get_members_fn, prefix="", recurse=True, remove_duplicate: bool = True
):
r"""Help yield various names + members of modules."""
memo = set()
modules = (
@ -2654,18 +2649,13 @@ class Module:
)
yield from gen
def buffers(
self, recurse: bool = True, *, persistent: Optional[bool] = None
) -> Iterator[Tensor]:
def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
r"""Return an iterator over module buffers.
Args:
recurse (bool, optional): if ``True``, then yields buffers of this module
recurse (bool): if True, then yields buffers of this module
and all submodules. Otherwise, yields only buffers that
are direct members of this module.
persistent (bool, optional): if ``True``, then yields persistent buffers only.
If ``False``, then yields only non-persistent buffers.
If ``None``, then yields both. Default: ``None``
Yields:
torch.Tensor: module buffer
@ -2679,28 +2669,20 @@ class Module:
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
"""
for _, buf in self.named_buffers(recurse=recurse, persistent=persistent):
for _, buf in self.named_buffers(recurse=recurse):
yield buf
def named_buffers(
self,
prefix: str = "",
recurse: bool = True,
remove_duplicate: bool = True,
*,
persistent: Optional[bool] = None,
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, Tensor]]:
r"""Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
Args:
prefix (str): prefix to prepend to all buffer names.
recurse (bool, optional): if ``True``, then yields buffers of this module
recurse (bool, optional): if True, then yields buffers of this module
and all submodules. Otherwise, yields only buffers that
are direct members of this module. Default: ``True``.
remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Default: ``True``.
persistent (bool, optional): if ``True``, then yields persistent buffers only.
If ``False``, then yields only non-persistent buffers.
If ``None``, then yields both. Default: ``None``.
are direct members of this module. Defaults to True.
remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.
Yields:
(str, torch.Tensor): Tuple containing the name and buffer
@ -2713,21 +2695,8 @@ class Module:
>>> print(buf.size())
"""
def _get_members(module):
if persistent is None:
yield from module._buffers.items()
elif persistent:
for k, v in module._buffers.items():
if k not in module._non_persistent_buffers_set:
yield k, v
else: # persistent=False
for k, v in module._buffers.items():
if k in module._non_persistent_buffers_set:
yield k, v
gen = self._named_members(
_get_members,
lambda module: module._buffers.items(),
prefix=prefix,
recurse=recurse,
remove_duplicate=remove_duplicate,