mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Allow register_buffer with Tensor-like object (#159455)
As torch allows extending the tensor with `__torch_function__`, it would be desirable to allow registering it as a buffer. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159455 Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
@ -579,6 +579,22 @@ class TestNN(NNTestCase):
|
||||
m.buffer_name = Buffer(buffer3)
|
||||
self.assertEqual(m.buffer_name, Buffer(buffer3))
|
||||
|
||||
def test_register_buffer_allows_tensor_like_object(self):
|
||||
class TensorLike:
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
raise NotImplementedError(f"TensorLike.__torch_function__: {func}")
|
||||
|
||||
buffer1 = TensorLike()
|
||||
buffer2 = TensorLike()
|
||||
m = nn.Module()
|
||||
m.register_buffer('buffer_name', buffer1)
|
||||
self.assertEqual(m.buffer_name, buffer1)
|
||||
self.assertEqual(m.get_buffer('buffer_name'), buffer1)
|
||||
m.buffer_name = buffer2
|
||||
self.assertEqual(m.buffer_name, buffer2)
|
||||
self.assertEqual(m.get_buffer('buffer_name'), buffer2)
|
||||
|
||||
def test_get_buffer(self):
|
||||
m = nn.Module()
|
||||
buffer1 = torch.randn(2, 3)
|
||||
|
@ -568,7 +568,9 @@ class Module:
|
||||
raise KeyError('buffer name can\'t be empty string ""')
|
||||
elif hasattr(self, name) and name not in self._buffers:
|
||||
raise KeyError(f"attribute '{name}' already exists")
|
||||
elif tensor is not None and not isinstance(tensor, torch.Tensor):
|
||||
elif tensor is not None and not (
|
||||
isinstance(tensor, torch.Tensor) or hasattr(tensor, "__torch_function__")
|
||||
):
|
||||
raise TypeError(
|
||||
f"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' "
|
||||
"(torch Tensor or None required)"
|
||||
@ -2024,7 +2026,10 @@ class Module:
|
||||
else:
|
||||
buffers = self.__dict__.get("_buffers")
|
||||
if isinstance(value, Buffer) or buffers is not None and name in buffers:
|
||||
if value is not None and not isinstance(value, torch.Tensor):
|
||||
if value is not None and not (
|
||||
isinstance(value, torch.Tensor)
|
||||
or hasattr(value, "__torch_function__")
|
||||
):
|
||||
raise TypeError(
|
||||
f"cannot assign '{torch.typename(value)}' as buffer '{name}' "
|
||||
"(torch.nn.Buffer, torch.Tensor or None expected)"
|
||||
|
Reference in New Issue
Block a user