[inductor] Add unbacked symints binding in ShapeProp (#144605)

Summary: ShapeProp  doesn't know how to propagate unbacked. Patch it up to propagate unbacked symints like PropagateUnbackedSymInts.

Test Plan:
```
buck run mode/dev-nosan  fbcode//caffe2/test:fx -- -r test_shape_prop_unbacked_sym
```

Differential Revision: D68050073

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144605
Approved by: https://github.com/guowentian, https://github.com/pianpwk
This commit is contained in:
Shangdi Yu
2025-01-13 21:30:20 +00:00
committed by PyTorch MergeBot
parent 3c55669b88
commit e15f91337b
2 changed files with 31 additions and 0 deletions

View File

@ -1793,6 +1793,25 @@ class TestFX(JitTestCase):
if node.op in {'placeholder'}:
self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last_3d)
def test_shape_prop_unbacked_sym(self):
from torch._dynamo.utils import detect_fake_mode
class M(torch.nn.Module):
def forward(self, x: torch.Tensor):
return torch.nonzero(x)
inp = (torch.tensor([1, 0, 1, 0]),)
gm = torch.export.export(M(), inp).module()
fake_inputs = [
node.meta.get("val")
for node in gm.graph.nodes
if node.op == "placeholder"
]
inp = fake_inputs
fake_mode = detect_fake_mode(inp)
shape_prop.ShapeProp(gm=gm, fake_mode=fake_mode).propagate(*inp)
self.assertEqual(len(fake_mode.shape_env.pending_fresh_unbacked_symbols), 0)
def test_nn_module_stack(self):
class SubModule(torch.nn.Module):
def __init__(self) -> None:

View File

@ -154,6 +154,11 @@ class ShapeProp(torch.fx.Interpreter):
self.real_module = self.module
def run_node(self, n: Node) -> Any:
from torch.fx.experimental.symbolic_shapes import (
compute_unbacked_bindings,
rebind_unbacked,
)
try:
if self.fake_module is not None:
# Hacky swap. Alternatively, we could do this with overriding
@ -163,6 +168,7 @@ class ShapeProp(torch.fx.Interpreter):
if self.fake_mode is not None:
with self.fake_mode, enable_python_dispatcher():
result = super().run_node(n)
rebind_unbacked(self.fake_mode.shape_env, n, result)
else:
result = super().run_node(n)
finally:
@ -187,6 +193,12 @@ class ShapeProp(torch.fx.Interpreter):
if found_tensor:
n.meta["tensor_meta"] = meta
if self.fake_mode:
if (shape_env := self.fake_mode.shape_env) and (
symbol_to_path := compute_unbacked_bindings(shape_env, result)
):
n.meta["unbacked_bindings"] = symbol_to_path
n.meta["type"] = type(result)
return result