mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "added persistent option to buffers and namedbuffers (#132994)"
This reverts commit 8707c6dfacaed293ddc40cbb5ecf5841568df0e6. Reverted https://github.com/pytorch/pytorch/pull/132994 on behalf of https://github.com/PaliC due to breaking internal pyre tests ([comment](https://github.com/pytorch/pytorch/pull/132994#issuecomment-2278487672))
This commit is contained in:
@ -369,24 +369,6 @@ class TestNN(NNTestCase):
|
||||
self.assertEqual(names(m.named_buffers(remove_duplicate=False)),
|
||||
["buffer1", "buffer2"])
|
||||
|
||||
# test persistent
|
||||
class Foo(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.register_buffer('buffer1', torch.empty(3, 5), persistent=True)
|
||||
self.register_buffer('buffer2', torch.empty(3, 5), persistent=False)
|
||||
|
||||
foo = Foo()
|
||||
# persistent = True
|
||||
self.assertEqual(len(list(foo.buffers(persistent=True))), 1)
|
||||
self.assertEqual(names(foo.named_buffers(persistent=True)), ["buffer1"])
|
||||
# persistent = False
|
||||
self.assertEqual(len(list(foo.buffers(persistent=False))), 1)
|
||||
self.assertEqual(names(foo.named_buffers(persistent=False)), ["buffer2"])
|
||||
# persistent = None
|
||||
self.assertEqual(len(list(foo.buffers(persistent=None))), 2)
|
||||
self.assertEqual(names(foo.named_buffers(persistent=None)), ["buffer1", "buffer2"])
|
||||
|
||||
def test_buffer_bad_module_subclass(self):
|
||||
class MyBadModule(nn.Linear):
|
||||
def __init__(self) -> None:
|
||||
|
@ -412,16 +412,11 @@ class _RemoteModule(nn.Module):
|
||||
) -> Iterator[Tuple[str, Parameter]]:
|
||||
_raise_not_supported(self.named_parameters.__name__)
|
||||
|
||||
def buffers(self, recurse: bool = True, *, persistent: Optional[bool] = None) -> Iterator[Tensor]: # type: ignore[return]
|
||||
def buffers(self, recurse: bool = True) -> Iterator[Tensor]: # type: ignore[return]
|
||||
_raise_not_supported(self.buffers.__name__)
|
||||
|
||||
def named_buffers( # type: ignore[return]
|
||||
self,
|
||||
prefix: str = "",
|
||||
recurse: bool = True,
|
||||
remove_duplicate: bool = True,
|
||||
*,
|
||||
persistent: Optional[bool] = None,
|
||||
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
|
||||
) -> Iterator[Tuple[str, Tensor]]:
|
||||
_raise_not_supported(self.named_buffers.__name__)
|
||||
|
||||
|
@ -10,7 +10,6 @@ from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
@ -2574,12 +2573,8 @@ class Module:
|
||||
return _IncompatibleKeys(missing_keys, unexpected_keys)
|
||||
|
||||
def _named_members(
|
||||
self,
|
||||
get_members_fn: Callable[[Self], Iterable[Tuple[str, Any]]],
|
||||
prefix="",
|
||||
recurse: bool = True,
|
||||
remove_duplicate: bool = True,
|
||||
) -> Iterator[Tuple[str, Any]]:
|
||||
self, get_members_fn, prefix="", recurse=True, remove_duplicate: bool = True
|
||||
):
|
||||
r"""Help yield various names + members of modules."""
|
||||
memo = set()
|
||||
modules = (
|
||||
@ -2654,18 +2649,13 @@ class Module:
|
||||
)
|
||||
yield from gen
|
||||
|
||||
def buffers(
|
||||
self, recurse: bool = True, *, persistent: Optional[bool] = None
|
||||
) -> Iterator[Tensor]:
|
||||
def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
|
||||
r"""Return an iterator over module buffers.
|
||||
|
||||
Args:
|
||||
recurse (bool, optional): if ``True``, then yields buffers of this module
|
||||
recurse (bool): if True, then yields buffers of this module
|
||||
and all submodules. Otherwise, yields only buffers that
|
||||
are direct members of this module.
|
||||
persistent (bool, optional): if ``True``, then yields persistent buffers only.
|
||||
If ``False``, then yields only non-persistent buffers.
|
||||
If ``None``, then yields both. Default: ``None``
|
||||
|
||||
Yields:
|
||||
torch.Tensor: module buffer
|
||||
@ -2679,28 +2669,20 @@ class Module:
|
||||
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
|
||||
|
||||
"""
|
||||
for _, buf in self.named_buffers(recurse=recurse, persistent=persistent):
|
||||
for _, buf in self.named_buffers(recurse=recurse):
|
||||
yield buf
|
||||
|
||||
def named_buffers(
|
||||
self,
|
||||
prefix: str = "",
|
||||
recurse: bool = True,
|
||||
remove_duplicate: bool = True,
|
||||
*,
|
||||
persistent: Optional[bool] = None,
|
||||
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
|
||||
) -> Iterator[Tuple[str, Tensor]]:
|
||||
r"""Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
|
||||
|
||||
Args:
|
||||
prefix (str): prefix to prepend to all buffer names.
|
||||
recurse (bool, optional): if ``True``, then yields buffers of this module
|
||||
recurse (bool, optional): if True, then yields buffers of this module
|
||||
and all submodules. Otherwise, yields only buffers that
|
||||
are direct members of this module. Default: ``True``.
|
||||
remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Default: ``True``.
|
||||
persistent (bool, optional): if ``True``, then yields persistent buffers only.
|
||||
If ``False``, then yields only non-persistent buffers.
|
||||
If ``None``, then yields both. Default: ``None``.
|
||||
are direct members of this module. Defaults to True.
|
||||
remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.
|
||||
|
||||
Yields:
|
||||
(str, torch.Tensor): Tuple containing the name and buffer
|
||||
@ -2713,21 +2695,8 @@ class Module:
|
||||
>>> print(buf.size())
|
||||
|
||||
"""
|
||||
|
||||
def _get_members(module):
|
||||
if persistent is None:
|
||||
yield from module._buffers.items()
|
||||
elif persistent:
|
||||
for k, v in module._buffers.items():
|
||||
if k not in module._non_persistent_buffers_set:
|
||||
yield k, v
|
||||
else: # persistent=False
|
||||
for k, v in module._buffers.items():
|
||||
if k in module._non_persistent_buffers_set:
|
||||
yield k, v
|
||||
|
||||
gen = self._named_members(
|
||||
_get_members,
|
||||
lambda module: module._buffers.items(),
|
||||
prefix=prefix,
|
||||
recurse=recurse,
|
||||
remove_duplicate=remove_duplicate,
|
||||
|
Reference in New Issue
Block a user