Make adding Buffers more like adding Parameters (#125971)

Add similar semantics for creating a buffer object similar to creating a parameter. This is done by introducing a new Buffer class that can be used for type disambiguation. The underlying functionality of registering a buffer remains the same as the register_buffer method has not been changed. The persistent parameter in the Buffer type is to indicate whether a buffer object should be persistent or not. Other non-test changes have to do with getting the new Buffer type recognized by inductor and dynamo. Remaining changes are test changes to make sure that the Buffer type can be used as a drop in replacement for register_buffer as it just leads to register_buffer being called. The addition of this new functionality still allows for normal tensors to be used as buffers so these changes are intended to be backwards compatible.

Fixes #35735

Co-authored-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125971
Approved by: https://github.com/albanD, https://github.com/anijain2305, https://github.com/mlazos
This commit is contained in:
ekamiti
2024-07-31 10:32:37 +00:00
committed by PyTorch MergeBot
parent a94e507c39
commit 9e473fd868
57 changed files with 389 additions and 261 deletions

View File

@ -1250,8 +1250,8 @@ class {test_classname}(torch.nn.Module):
self.seq = torch.nn.Sequential(torch.nn.BatchNorm1d(2, 2))
self.linear = torch.nn.Linear(2, 2)
self.attr = torch.randn(2)
self.register_buffer("attr2", torch.randn(2))
self.register_buffer("attr3", torch.ones(2, dtype=torch.int32))
self.attr2 = torch.nn.Buffer(torch.randn(2))
self.attr3 = torch.nn.Buffer(torch.ones(2, dtype=torch.int32))
def forward(self, x):
return self.linear(self.seq(self.W + self.attr + self.attr2 + self.attr3 + x))