mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Make adding Buffers more like adding Parameters (#125971)
Add similar semantics for creating a buffer object similar to creating a parameter. This is done by introducing a new Buffer class that can be used for type disambiguation. The underlying functionality of registering a buffer remains the same as the register_buffer method has not been changed. The persistent parameter in the Buffer type is to indicate whether a buffer object should be persistent or not. Other non-test changes have to do with getting the new Buffer type recognized by inductor and dynamo. Remaining changes are test changes to make sure that the Buffer type can be used as a drop in replacement for register_buffer as it just leads to register_buffer being called. The addition of this new functionality still allows for normal tensors to be used as buffers so these changes are intended to be backwards compatible. Fixes #35735 Co-authored-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/125971 Approved by: https://github.com/albanD, https://github.com/anijain2305, https://github.com/mlazos
This commit is contained in:
committed by
PyTorch MergeBot
parent
a94e507c39
commit
9e473fd868
@ -1250,8 +1250,8 @@ class {test_classname}(torch.nn.Module):
|
||||
self.seq = torch.nn.Sequential(torch.nn.BatchNorm1d(2, 2))
|
||||
self.linear = torch.nn.Linear(2, 2)
|
||||
self.attr = torch.randn(2)
|
||||
self.register_buffer("attr2", torch.randn(2))
|
||||
self.register_buffer("attr3", torch.ones(2, dtype=torch.int32))
|
||||
self.attr2 = torch.nn.Buffer(torch.randn(2))
|
||||
self.attr3 = torch.nn.Buffer(torch.ones(2, dtype=torch.int32))
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(self.seq(self.W + self.attr + self.attr2 + self.attr3 + x))
|
||||
|
Reference in New Issue
Block a user