[export] Add meta[val] to getattr nodes (#154934)

Fixes [P1830293318](https://www.internalfb.com/intern/paste/P1830293318/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154934
Approved by: https://github.com/yushangdi, https://github.com/muchulee8
Author: angelayi
Date: 2025-06-13 05:48:21 +00:00
Committed by: PyTorch MergeBot
Parent: 25717da8c8
Commit: 0860606729
5 changed files with 20 additions and 21 deletions
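
In make_fx tracing, get_attr nodes created for captured tensor constants previously carried no meta["val"], so export had to backfill it after the fact (see the removed backfill in set_missing_meta_vals below). A minimal sketch of the observable effect, assuming a tensor constant closed over by the traced function:

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    const = torch.tensor([1.0])

    def fn(x, y):
        # `const` is closed over and appears in the traced graph as a
        # get_attr node (self._tensor_constant0)
        return x * y + const

    gm = make_fx(fn)(torch.randn(10), torch.randn(10))
    for node in gm.graph.nodes:
        if node.op == "get_attr":
            # with this change the tracer records the value's metadata here
            print(node.target, node.meta.get("val"))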

View File

@@ -175,7 +175,7 @@ class GraphModule(torch.nn.Module):
 class <lambda>(torch.nn.Module):
     def forward(self, arg0_1: "f32[10]", arg1_1: "f32[10]"):
         mul: "f32[10]" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
-        _tensor_constant0 = self._tensor_constant0
+        _tensor_constant0: "f32[1]" = self._tensor_constant0
         add: "f32[10]" = torch.ops.aten.add.Tensor(mul, _tensor_constant0); mul = _tensor_constant0 = None
         return (add,)
 """,  # NOQA: B950

View File

@@ -2858,7 +2858,7 @@ class <lambda>(torch.nn.Module):
         cat: "f64[9, 5]" = torch.ops.aten.cat.default([randn, randn_1, randn_2]); randn = randn_1 = randn_2 = None
         zeros: "i64[1]" = torch.ops.aten.zeros.default([1], dtype = torch.int64, device = device(type='cpu'), pin_memory = False)
-        _tensor_constant0 = self._tensor_constant0
+        _tensor_constant0: "i64[3]" = self._tensor_constant0
         lift_fresh_copy: "i64[3]" = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
         cumsum: "i64[3]" = torch.ops.aten.cumsum.default(lift_fresh_copy, 0); lift_fresh_copy = None
         cat_1: "i64[4]" = torch.ops.aten.cat.default([zeros, cumsum]); zeros = cumsum = None

View File

@@ -113,14 +113,18 @@ def _(x):
 class AOTInductorTestsTemplate:
     def test_custom_op_add(self) -> None:
         class M(torch.nn.Module):
-            def forward(self, x, y):
-                return torch.ops.aoti_custom_ops.custom_add(x, y)
+            def __init__(self, device):
+                super().__init__()
+                self.device = device
+                self.w = torch.randn(3, 3, device=device)
 
-        m = M().to(device=self.device)
-        args = (
-            torch.randn(3, 3, device=self.device),
-            torch.randn(3, 3, device=self.device),
-        )
+            def forward(self, x):
+                const = torch.tensor([1], device=self.device)
+                x = torch.ops.aoti_custom_ops.custom_add(x, const)
+                return torch.ops.aoti_custom_ops.custom_add(x, self.w)
+
+        m = M(self.device).to(device=self.device)
+        args = (torch.randn(3, 3, device=self.device),)
         self.check_model(m, args)
 
     def test_custom_op_add_output_path(self) -> None:
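
The rewritten test routes x through two tensor constants: `const`, created inside forward, and the plain attribute `self.w`. Neither is a parameter or buffer, so both get lifted as constants at export time, which is exactly the case that needs meta["val"] on get_attr nodes. A rough stand-in sketch of the shape of this test (plain addition substituted for aoti_custom_ops.custom_add, which is an assumption for illustration):

    import torch

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.w = torch.randn(3, 3)  # plain tensor attribute -> lifted constant

        def forward(self, x):
            const = torch.tensor([1.0])  # constant created in forward -> lifted constant
            return (x + const) + self.w  # stand-in for the two custom_add calls

    ep = torch.export.export(M(), (torch.randn(3, 3),))
    print(ep.graph_signature)  # both constants show up as lifted constant inputs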

View File

@@ -93,7 +93,6 @@ from torch.fx.experimental.symbolic_shapes import (
     ShapeEnv,
 )
 from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
-from torch.fx.graph_module import _get_attr
 from torch.utils._pytree import TreeSpec
 from torch.utils._sympy.value_ranges import ValueRangeError
@@ -1809,7 +1808,6 @@ def set_missing_meta_vals(gm, flat_args, num_params_buffers):
     # need to have their metadata set before lifting them because it is needed
     # for computing the exported program's signature.
     index = 0
-    fake_mode = detect_fake_mode(flat_args)
     for node in gm.graph.nodes:
         if node.op == "placeholder":
             if index >= num_params_buffers:
@@ -1817,16 +1815,6 @@ def set_missing_meta_vals(gm, flat_args, num_params_buffers):
             if not isinstance(user_arg, torch.Tensor):
                 node.meta["val"] = user_arg
             index += 1
-        if node.op == "get_attr":
-            val = _get_attr(gm, node.target)
-            if isinstance(val, torch.Tensor):
-                assert "val" not in node.meta, (
-                    f"Found attribute {node.target} that has already been fakified "
-                    "but not yet lifted as an input. This should be impossible because "
-                    "(1) we should have already fakified AND lifted params/buffers "
-                    "(2) we should have NOT yet fakified OR lifted tensor constants. "
-                )
-                node.meta["val"] = fake_mode.from_tensor(val, static_shapes=True)
 
 
 def _find_node(gm: torch.fx.GraphModule, name: str) -> torch.fx.Node:
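
With the tracer setting meta["val"] at node-creation time (next file), this export-side backfill and its detect_fake_mode dependency become dead code; set_missing_meta_vals only has to handle non-tensor placeholder args. A minimal sketch of the invariant downstream passes can now rely on, assuming a make_fx-produced graph with a captured constant:

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    w = torch.randn(3)
    gm = make_fx(lambda x: x + w)(torch.randn(3))
    # guaranteed by PythonKeyTracer now, rather than patched up in export
    assert all("val" in n.meta for n in gm.graph.nodes if n.op == "get_attr")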

View File

@@ -1146,6 +1146,13 @@ class PythonKeyTracer(Tracer):
                 stack_trace = traceback.StackSummary.from_list(stack_trace)
                 node.meta["stack_trace"] = "".join(stack_trace.format()).strip()
 
+        if kind == "get_attr":
+            assert isinstance(target, str)
+            attr = getattr(self.root, target)
+            if isinstance(attr, torch.Tensor):
+                with disable_proxy_modes_tracing():
+                    node.meta["val"] = extract_val(attr)
+
         def map_fn(v: Any) -> Optional[_ExtractValType]:
             if not isinstance(v, torch.fx.Node) or "val" not in v.meta:
                 return None
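
Here extract_val runs under disable_proxy_modes_tracing so that snapshotting the attribute's metadata does not itself emit trace nodes. Under fake tracing the recorded val is expected to be a FakeTensor, which is what export's signature computation consumes; a hedged sketch of checking that expectation:

    import torch
    from torch._subclasses.fake_tensor import FakeTensor
    from torch.fx.experimental.proxy_tensor import make_fx

    w = torch.randn(3)
    gm = make_fx(lambda x: x + w, tracing_mode="fake")(torch.randn(3))
    for node in gm.graph.nodes:
        if node.op == "get_attr":
            print(isinstance(node.meta["val"], FakeTensor))  # expected: True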