Files
pytorch/test/dynamo/test_buffers_override.py
Sean McGovern 297805fd8f Typo fixes for "overridden" in comments and function names (#155944)
This word appears often in class descriptions and is not spelled consistently. Update comments and some function names to use the correct spelling throughout, which also makes the codebase easier to search.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155944
Approved by: https://github.com/Skylion007
2025-06-14 03:37:38 +00:00

50 lines
1.6 KiB
Python

# Owner(s): ["module: dynamo"]
import torch
import torch._dynamo.test_case
import torch.nn as nn
class TestBuffersOverride(torch._dynamo.test_case.TestCase):
def test_buffers_override(self):
class SomeModel(nn.Module):
def __init__(self):
super().__init__()
# Override buffers; should not cause breakage
# this is because we use `named_buffers` for
# static marking
self.register_buffer("A", torch.ones(3, 3))
self.buffers = []
def forward(self):
return self.A * torch.zeros(1, 1)
model = SomeModel().to(torch.device("cpu"))
compiled_model = torch.compile(model)
self.assertEqual(compiled_model.A, torch.ones(3, 3))
compiled_model()
def test_named_buffers_override(self):
class SomeModel(nn.Module):
def __init__(self):
super().__init__()
# Override buffers; should not cause breakage
# but skip the marking static here since
# named_buffers is overridden
self.register_buffer("B", torch.ones(3, 3))
self.named_buffers = []
def forward(self):
return self.B * torch.zeros(1, 1)
model = SomeModel().to(torch.device("cpu"))
compiled_model = torch.compile(model)
self.assertEqual(compiled_model.B, torch.ones(3, 3))
compiled_model()
if __name__ == "__main__":
    # When executed directly as a script, hand control to the shared Dynamo
    # test runner (the same ``run_tests`` exposed by torch._dynamo.test_case).
    torch._dynamo.test_case.run_tests()