mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] Add unbacked symints binding in ShapeProp (#144605)
Summary: ShapeProp doesn't know how to propagate unbacked. Patch it up to propagate unbacked symints like PropagateUnbackedSymInts. Test Plan: ``` buck run mode/dev-nosan fbcode//caffe2/test:fx -- -r test_shape_prop_unbacked_sym ``` Differential Revision: D68050073 Pull Request resolved: https://github.com/pytorch/pytorch/pull/144605 Approved by: https://github.com/guowentian, https://github.com/pianpwk
This commit is contained in:
committed by
PyTorch MergeBot
parent
3c55669b88
commit
e15f91337b
@ -1793,6 +1793,25 @@ class TestFX(JitTestCase):
|
||||
if node.op in {'placeholder'}:
|
||||
self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last_3d)
|
||||
|
||||
def test_shape_prop_unbacked_sym(self):
|
||||
from torch._dynamo.utils import detect_fake_mode
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor):
|
||||
return torch.nonzero(x)
|
||||
|
||||
inp = (torch.tensor([1, 0, 1, 0]),)
|
||||
gm = torch.export.export(M(), inp).module()
|
||||
fake_inputs = [
|
||||
node.meta.get("val")
|
||||
for node in gm.graph.nodes
|
||||
if node.op == "placeholder"
|
||||
]
|
||||
inp = fake_inputs
|
||||
fake_mode = detect_fake_mode(inp)
|
||||
shape_prop.ShapeProp(gm=gm, fake_mode=fake_mode).propagate(*inp)
|
||||
self.assertEqual(len(fake_mode.shape_env.pending_fresh_unbacked_symbols), 0)
|
||||
|
||||
def test_nn_module_stack(self):
|
||||
class SubModule(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
|
@ -154,6 +154,11 @@ class ShapeProp(torch.fx.Interpreter):
|
||||
self.real_module = self.module
|
||||
|
||||
def run_node(self, n: Node) -> Any:
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
compute_unbacked_bindings,
|
||||
rebind_unbacked,
|
||||
)
|
||||
|
||||
try:
|
||||
if self.fake_module is not None:
|
||||
# Hacky swap. Alternatively, we could do this with overriding
|
||||
@ -163,6 +168,7 @@ class ShapeProp(torch.fx.Interpreter):
|
||||
if self.fake_mode is not None:
|
||||
with self.fake_mode, enable_python_dispatcher():
|
||||
result = super().run_node(n)
|
||||
rebind_unbacked(self.fake_mode.shape_env, n, result)
|
||||
else:
|
||||
result = super().run_node(n)
|
||||
finally:
|
||||
@ -187,6 +193,12 @@ class ShapeProp(torch.fx.Interpreter):
|
||||
if found_tensor:
|
||||
n.meta["tensor_meta"] = meta
|
||||
|
||||
if self.fake_mode:
|
||||
if (shape_env := self.fake_mode.shape_env) and (
|
||||
symbol_to_path := compute_unbacked_bindings(shape_env, result)
|
||||
):
|
||||
n.meta["unbacked_bindings"] = symbol_to_path
|
||||
|
||||
n.meta["type"] = type(result)
|
||||
return result
|
||||
|
||||
|
Reference in New Issue
Block a user