Allow register_buffer with Tensor-like object (#159455)

As torch allows extending the tensor with `__torch_function__`, it would be desirable to allow registering it as a buffer.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159455
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
linmin
2025-08-01 15:31:38 +00:00
committed by PyTorch MergeBot
parent 7c37b8e1e0
commit 2a286cbdf4
2 changed files with 23 additions and 2 deletions

View File

@ -579,6 +579,22 @@ class TestNN(NNTestCase):
m.buffer_name = Buffer(buffer3)
self.assertEqual(m.buffer_name, Buffer(buffer3))
def test_register_buffer_allows_tensor_like_object(self):
class TensorLike:
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
raise NotImplementedError(f"TensorLike.__torch_function__: {func}")
buffer1 = TensorLike()
buffer2 = TensorLike()
m = nn.Module()
m.register_buffer('buffer_name', buffer1)
self.assertEqual(m.buffer_name, buffer1)
self.assertEqual(m.get_buffer('buffer_name'), buffer1)
m.buffer_name = buffer2
self.assertEqual(m.buffer_name, buffer2)
self.assertEqual(m.get_buffer('buffer_name'), buffer2)
def test_get_buffer(self):
m = nn.Module()
buffer1 = torch.randn(2, 3)

View File

@ -568,7 +568,9 @@ class Module:
raise KeyError('buffer name can\'t be empty string ""')
elif hasattr(self, name) and name not in self._buffers:
raise KeyError(f"attribute '{name}' already exists")
elif tensor is not None and not isinstance(tensor, torch.Tensor):
elif tensor is not None and not (
isinstance(tensor, torch.Tensor) or hasattr(tensor, "__torch_function__")
):
raise TypeError(
f"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' "
"(torch Tensor or None required)"
@ -2024,7 +2026,10 @@ class Module:
else:
buffers = self.__dict__.get("_buffers")
if isinstance(value, Buffer) or buffers is not None and name in buffers:
if value is not None and not isinstance(value, torch.Tensor):
if value is not None and not (
isinstance(value, torch.Tensor)
or hasattr(value, "__torch_function__")
):
raise TypeError(
f"cannot assign '{torch.typename(value)}' as buffer '{name}' "
"(torch.nn.Buffer, torch.Tensor or None expected)"