# Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00).
# See #145101 for details.
# Pull Request resolved: https://github.com/pytorch/pytorch/pull/145176
# Approved by: https://github.com/bobrenjc93
# 57 lines, 1.2 KiB, Python
# Owner(s): ["module: fx"]
|
|
|
|
from __future__ import annotations # type: ignore[attr-defined]
|
|
|
|
import torch
|
|
from torch.fx import symbolic_trace
|
|
|
|
|
|
class A:
    """Plain (non-``nn.Module``) callable that doubles a tensor.

    It is used in this file both as a parameter annotation (``a: A``) and as
    the runtime value passed for that parameter.
    """

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Return the element-wise sum of ``x`` with itself via ``torch.add``."""
        return torch.add(x, x)
|
|
|
|
|
|
# No forward references
|
|
class M1(torch.nn.Module):
    """Module whose ``forward`` is annotated with bare, unquoted names only."""

    def forward(self, x: torch.Tensor, a: A) -> torch.Tensor:
        """Call ``a`` on ``x`` and return whatever it produces."""
        return a(x)
|
|
|
|
|
|
# Forward references
|
|
class M2(torch.nn.Module):
    """Module whose ``forward`` annotations are all explicit string literals."""

    # The quotes are the whole point of this case: without them the signature
    # is identical to M1's and nothing new is exercised. Keep them even though
    # pyupgrade-style linters consider them redundant under postponed
    # annotations.
    def forward(self, x: "torch.Tensor", a: "A") -> "torch.Tensor":  # noqa: UP037
        """Call ``a`` on ``x`` and return whatever it produces."""
        return a(x)
|
|
|
|
|
|
# Non-torch annotation with no internal forward references
|
|
class M3(torch.nn.Module):
    """Module taking a builtin-generic (``list[...]``) annotated argument.

    The element type inside the generic is written as a bare, unquoted name.
    """

    def forward(self, x: list[torch.Tensor], a: A) -> torch.Tensor:
        """Call ``a`` on the first element of ``x`` and return the result."""
        return a(x[0])
|
|
|
|
|
|
# Non-torch annotation with internal forward references
|
|
class M4(torch.nn.Module):
    """Module taking a ``list[...]`` argument whose element type is quoted.

    The string literal nested inside the generic, and the quoted return type,
    are what distinguish this case from M3.
    """

    # Keep the quotes: stripping them makes this signature identical to M3's,
    # so the nested-forward-reference case would no longer be exercised.
    def forward(self, x: list["torch.Tensor"], a: A) -> "torch.Tensor":  # noqa: UP037
        """Call ``a`` on the first element of ``x`` and return the result."""
        return a(x[0])
|
|
|
|
|
|
# Shared input and the expected output (x + x) for every case below.
x = torch.rand(2, 3)
ref = torch.add(x, x)

# Each module is symbolically traced, the traced module is called with a real
# ``A`` instance, and the result is compared against ``ref``.
# ``torch.equal`` checks shape as well as values, so a broadcast-compatible
# result of the wrong shape cannot pass.
traced1 = symbolic_trace(M1())
res1 = traced1(x, A())
assert torch.equal(ref, res1), "M1: traced output differs from reference"

traced2 = symbolic_trace(M2())
res2 = traced2(x, A())
assert torch.equal(ref, res2), "M2: traced output differs from reference"

# M3 and M4 take the tensor wrapped in a list.
traced3 = symbolic_trace(M3())
res3 = traced3([x], A())
assert torch.equal(ref, res3), "M3: traced output differs from reference"

traced4 = symbolic_trace(M4())
res4 = traced4([x], A())
assert torch.equal(ref, res4), "M4: traced output differs from reference"
|