mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Allow register_buffer with Tensor-like object (#159455)
As torch allows extending the tensor with `__torch_function__`, it would be desirable to allow registering it as a buffer. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159455 Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
@ -579,6 +579,22 @@ class TestNN(NNTestCase):
|
||||
m.buffer_name = Buffer(buffer3)
|
||||
self.assertEqual(m.buffer_name, Buffer(buffer3))
|
||||
|
||||
def test_register_buffer_allows_tensor_like_object(self):
|
||||
class TensorLike:
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
raise NotImplementedError(f"TensorLike.__torch_function__: {func}")
|
||||
|
||||
buffer1 = TensorLike()
|
||||
buffer2 = TensorLike()
|
||||
m = nn.Module()
|
||||
m.register_buffer('buffer_name', buffer1)
|
||||
self.assertEqual(m.buffer_name, buffer1)
|
||||
self.assertEqual(m.get_buffer('buffer_name'), buffer1)
|
||||
m.buffer_name = buffer2
|
||||
self.assertEqual(m.buffer_name, buffer2)
|
||||
self.assertEqual(m.get_buffer('buffer_name'), buffer2)
|
||||
|
||||
def test_get_buffer(self):
|
||||
m = nn.Module()
|
||||
buffer1 = torch.randn(2, 3)
|
||||
|
Reference in New Issue
Block a user