Files
pytorch/test/jit/test_attr.py
Anthony Barbier bf7e290854 Add __main__ guards to jit tests (#154725)
This PR is part of a series attempting to re-submit https://github.com/pytorch/pytorch/pull/134592 as smaller PRs.

In jit tests:

- Add and use a common `raise_on_run_directly` method for when a user directly runs a test file that is not meant to be run that way. It prints the file the user should have run instead.
- Raise a RuntimeError on tests which have been disabled (not run)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154725
Approved by: https://github.com/clee2000
2025-06-16 10:28:45 +00:00

66 lines
2.1 KiB
Python

# Owner(s): ["oncall: jit"]
from typing import NamedTuple, Tuple
import torch
from torch.testing import FileCheck
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
class TestGetDefaultAttr(JitTestCase):
    """TorchScript tests for the builtin ``getattr``, with and without a default."""

    def test_getattr_with_default(self):
        """``getattr(obj, name, default)`` on a module falls back to ``default``.

        An attribute that exists compiles to ``prim::GetAttr``. A missing
        attribute produces no lookup in the graph; the default value is
        emitted instead.
        """

        class A(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.init_attr_val = 1.0

            def forward(self, x):
                y = getattr(self, "init_attr_val")  # noqa: B009
                w: list[float] = [1.0]
                # "missing" is never set on the module, so the default ``w`` is returned.
                z = getattr(self, "missing", w)
                z.append(y)
                return z

        # Eager mode: the default list [1.0] with init_attr_val (1.0) appended.
        result = A().forward(0.0)
        self.assertEqual([1.0, 1.0], result)

        scripted = torch.jit.script(A())
        graph = scripted.graph
        # The "init_attr_val" attribute exists
        FileCheck().check('prim::GetAttr[name="init_attr_val"]').run(graph)
        # The "missing" attribute does not exist, so there should be no corresponding GetAttr in AST
        FileCheck().check_not("missing").run(graph)
        # instead the getattr call will emit the default value, which is a list with one float element
        FileCheck().check("float[] = prim::ListConstruct").run(graph)

        # Also execute the compiled module so the default path is verified at
        # runtime, not only by graph inspection. ``x`` is unannotated, so
        # TorchScript types it as a Tensor; its value is never used.
        self.assertEqual(result, scripted(torch.tensor(0.0)))

    def test_getattr_named_tuple(self):
        """Scripted ``getattr`` on a NamedTuple matches eager results.

        Covers both present fields (``x``, ``y``) and a missing field (``z``),
        for which the default is returned.
        """
        # Declared global so the type can be resolved by name when
        # ``torch.jit.script`` compiles ``fn``. A local class referenced only
        # in annotations would not be a closure variable of ``fn``.
        global MyTuple

        class MyTuple(NamedTuple):
            x: str
            y: torch.Tensor

        def fn(x: MyTuple) -> Tuple[str, torch.Tensor, int]:
            return (
                getattr(x, "x", "fdsa"),
                getattr(x, "y", torch.ones((3, 3))),
                getattr(x, "z", 7),
            )

        inp = MyTuple(x="test", y=torch.ones(3, 3) * 2)
        ref = fn(inp)
        fn_s = torch.jit.script(fn)
        res = fn_s(inp)
        self.assertEqual(res, ref)

    def test_getattr_tuple(self):
        """Scripting ``getattr`` on a plain (unnamed) Tuple is a compile error."""

        def fn(x: Tuple[str, int]) -> int:
            return getattr(x, "x", 2)

        with self.assertRaisesRegex(RuntimeError, "but got a normal Tuple"):
            torch.jit.script(fn)
if __name__ == "__main__":
    # This module is not meant to be executed on its own; its tests are run
    # via test/test_jit.py, which raise_on_run_directly reports to the user.
    raise_on_run_directly("test/test_jit.py")