Allow register_buffer with Tensor-like object (#159455)

As torch allows extending the tensor with `__torch_function__`, it would be desirable to allow registering it as a buffer.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159455
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
linmin
2025-08-01 15:31:38 +00:00
committed by PyTorch MergeBot
parent 7c37b8e1e0
commit 2a286cbdf4
2 changed files with 23 additions and 2 deletions

View File

@ -579,6 +579,22 @@ class TestNN(NNTestCase):
m.buffer_name = Buffer(buffer3)
self.assertEqual(m.buffer_name, Buffer(buffer3))
def test_register_buffer_allows_tensor_like_object(self):
class TensorLike:
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
raise NotImplementedError(f"TensorLike.__torch_function__: {func}")
buffer1 = TensorLike()
buffer2 = TensorLike()
m = nn.Module()
m.register_buffer('buffer_name', buffer1)
self.assertEqual(m.buffer_name, buffer1)
self.assertEqual(m.get_buffer('buffer_name'), buffer1)
m.buffer_name = buffer2
self.assertEqual(m.buffer_name, buffer2)
self.assertEqual(m.get_buffer('buffer_name'), buffer2)
def test_get_buffer(self):
m = nn.Module()
buffer1 = torch.randn(2, 3)