mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Original commit changeset: 81319beb97f3 Original Phabricator Diff: D47961182 Test Plan: revert to maintain backward compat with legacy ads_dper3 production package. Read details in: S357822 Reviewed By: atuljangra Differential Revision: D48131623 @diff-train-skip-merge (D48131623 landed internally) Pull Request resolved: https://github.com/pytorch/pytorch/pull/106743 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
891bb259f8
commit
bc88028e8e
@ -18,7 +18,7 @@ class MockModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.l1 = torch.nn.Linear(1, 1)
|
||||
self.buffer = torch.nn.Buffer(torch.ones(1))
|
||||
self.register_buffer('buffer', torch.ones(1))
|
||||
self.foo = 0.0
|
||||
|
||||
def forward(self, x):
|
||||
@ -30,8 +30,8 @@ class MockTiedModule(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.l1 = torch.nn.Linear(1, 1)
|
||||
self.tied_bias = self.l1.bias
|
||||
self.buffer = torch.nn.Buffer(torch.ones(1))
|
||||
self.tied_buffer = self.buffer
|
||||
self.register_buffer('buffer', torch.ones(1))
|
||||
self.register_buffer('tied_buffer', self.buffer)
|
||||
|
||||
def forward(self, x):
|
||||
return self.l1(x) + self.tied_bias + self.buffer + self.tied_buffer
|
||||
@ -408,7 +408,7 @@ class TestStatelessFunctionalAPI(TestCase):
|
||||
def test_tied_weights_warns(self, functional_call):
|
||||
module = MockModule()
|
||||
module.tied_bias = module.l1.bias
|
||||
module.tied_buffer = torch.nn.Buffer(module.buffer)
|
||||
module.register_buffer("tied_buffer", module.buffer)
|
||||
|
||||
@parametrize("functional_call", [
|
||||
subtest(torch.func.functional_call, "torch_func"),
|
||||
@ -613,7 +613,7 @@ class TestStatelessFunctionalAPI(TestCase):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.foo = torch.nn.Buffer(torch.tensor([0.0]))
|
||||
self.register_buffer('foo', torch.tensor([0.0]))
|
||||
|
||||
def forward(self, x):
|
||||
self.foo = self.foo + 1
|
||||
@ -637,7 +637,7 @@ class TestStatelessFunctionalAPI(TestCase):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.foo = torch.nn.Buffer(torch.tensor([0.0]))
|
||||
self.register_buffer('foo', torch.tensor([0.0]))
|
||||
|
||||
def forward(self, x):
|
||||
self.foo.add_(1)
|
||||
@ -759,7 +759,7 @@ class TestStatelessFunctionalAPI(TestCase):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.l1 = torch.nn.Linear(1, 1)
|
||||
self.buffer = torch.nn.Buffer(torch.ones(1))
|
||||
self.register_buffer('buffer', torch.ones(1))
|
||||
|
||||
def forward(self, x):
|
||||
parameters = tuple(self.parameters())
|
||||
|
Reference in New Issue
Block a user