mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
This word appears often in class descriptions and is not consistently spelled. Update comments and some function names to use the correct spelling consistently. Facilitates searching the codebase. Pull Request resolved: https://github.com/pytorch/pytorch/pull/155944 Approved by: https://github.com/Skylion007
50 lines
1.6 KiB
Python
50 lines
1.6 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
|
|
import torch
|
|
import torch._dynamo.test_case
|
|
import torch.nn as nn
|
|
|
|
|
|
class TestBuffersOverride(torch._dynamo.test_case.TestCase):
|
|
def test_buffers_override(self):
|
|
class SomeModel(nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
# Override buffers; should not cause breakage
|
|
# this is because we use `named_buffers` for
|
|
# static marking
|
|
self.register_buffer("A", torch.ones(3, 3))
|
|
self.buffers = []
|
|
|
|
def forward(self):
|
|
return self.A * torch.zeros(1, 1)
|
|
|
|
model = SomeModel().to(torch.device("cpu"))
|
|
compiled_model = torch.compile(model)
|
|
self.assertEqual(compiled_model.A, torch.ones(3, 3))
|
|
compiled_model()
|
|
|
|
def test_named_buffers_override(self):
|
|
class SomeModel(nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
# Override buffers; should not cause breakage
|
|
# but skip the marking static here since
|
|
# named_buffers is overridden
|
|
self.register_buffer("B", torch.ones(3, 3))
|
|
self.named_buffers = []
|
|
|
|
def forward(self):
|
|
return self.B * torch.zeros(1, 1)
|
|
|
|
model = SomeModel().to(torch.device("cpu"))
|
|
compiled_model = torch.compile(model)
|
|
self.assertEqual(compiled_model.B, torch.ones(3, 3))
|
|
compiled_model()
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script rather than collected by a test runner: hand off to
    # the run_tests entry point exported by torch._dynamo.test_case, which is
    # already imported at module level.
    torch._dynamo.test_case.run_tests()
|