Compare commits


1 Commit

Commit 49e98a5913: Support SymInt placeholder in wrapper fxir (#167757)
Summary:

Add support for SymInt placeholders in the FX IR wrapper backend (fx_wrapper).

Added two test cases covering dynamic reshape:
- dynamic shape info coming from tensor metadata (tmd) on placeholders
- dynamic shape info coming from SymInt placeholders

Test Plan:
test_reshape_dynamic_ph
test_reshape_dynamic_tmd

Reviewed By: blaine-rister

Differential Revision: D86984100
Date: 2025-11-14 12:42:11 -08:00
2 changed files with 50 additions and 13 deletions
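
For orientation, here is a minimal sketch of the flow the new tests exercise, assembled from the diff below. The module, input values, dynamic-shape spec, and the fx_wrapper option are taken from the tests; the rest paraphrases what the test's check() helper does and is not the helper itself:

import torch
import torch.utils._pytree as pytree

class Reshape(torch.nn.Module):
    def forward(self, x, shape):
        return torch.reshape(x, shape) + 2

args = (torch.randn(12, 14), [6, 28])
ds = {
    "x": (torch.export.Dim.AUTO, torch.export.Dim.AUTO),
    "shape": [torch.export.Dim.AUTO, torch.export.Dim.AUTO],
}

# Export with dynamic shapes so the reshape target stays symbolic.
ep = torch.export.export(Reshape(), args, dynamic_shapes=ds)

# AOT-compile to an FX IR wrapper graph; with this change the SymInt
# placeholders for `shape` are kept instead of being dropped.
gm = torch._inductor.aot_compile(ep.module(), args, options={"fx_wrapper": True})

# The wrapper gm takes flattened args: the shape list becomes two scalar args.
flat_args, _ = pytree.tree_flatten(args)
out = gm(*flat_args)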


@@ -831,7 +831,9 @@ class AOTFxirTestCase(InductorTestCase):
         gm = torch._inductor.aot_compile(
             ep.module(), inp, options={"fx_wrapper": True, **test_config}
         )
-        self.assertTrue(same(model(*inp), gm(*inp)))
+        # Flatten args for fx_wrapper gm
+        flat_args, _ = pytree.tree_flatten(inp)
+        self.assertTrue(same(model(*inp), gm(*flat_args)))
 
         for node in gm.graph.nodes:
             if (
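
As context for the hunk above: pytree.tree_flatten turns the nested input tuple into a flat list of leaves plus a spec, which is why the fx_wrapper gm is now called with flattened args. A standalone illustration, with input values invented for the example and assuming pytree refers to torch.utils._pytree as is conventional in these test files:

import torch
import torch.utils._pytree as pytree

inp = (torch.randn(12, 14), [6, 28])
flat_args, spec = pytree.tree_flatten(inp)
# flat_args is [tensor, 6, 28]: the shape list is split into scalar leaves,
# matching the flat (tensor and SymInt) placeholders of the wrapper graph.
print(len(flat_args), spec)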
@@ -1182,6 +1184,38 @@ def forward(self, arg0_1, arg1_1, arg2_1):
         compiled_out = compiled(*args)
         self.assertEqual(compiled_out.shape, shape)
 
+    def test_reshape_dynamic_ph(self):
+        """
+        Test dynamic reshape with the target shape passed as SymInt placeholders
+        """
+
+        class TestModule(torch.nn.Module):
+            def forward(self, x, shape):
+                return torch.reshape(x, shape) + 2
+
+        ds = {
+            "x": (torch.export.Dim.AUTO, torch.export.Dim.AUTO),
+            "shape": [torch.export.Dim.AUTO, torch.export.Dim.AUTO],
+        }
+        args = (torch.randn((12, 14), device=self.device), [6, 28])
+        gm = self.check(TestModule(), args, ds)
+
+    def test_reshape_dynamic_tmd(self):
+        """
+        Test dynamic reshape with the shape derived from tensor metadata
+        """
+
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                new_shape = [x.shape[0] // 2, x.shape[1] * 2]
+                return torch.reshape(x, new_shape) + 2
+
+        ds = {
+            "x": (torch.export.Dim.AUTO, torch.export.Dim.AUTO),
+        }
+        args = (torch.randn((12, 14), device=self.device),)
+        self.check(TestModule(), args, ds)
+
 
 class TestReplaceFloorDiv(InductorTestCase):
     """


@@ -2537,23 +2537,26 @@ def _extract_inputs_from_exported_gm(
     fake_inputs = [
         node.meta.get("val") for node in gm.graph.nodes if node.op == "placeholder"
     ]
-    # Replace non-tensor (constant) inputs with Nones, since these are not being
-    # used anyways by the graph
-    fake_inputs = [
-        inp if isinstance(inp, torch.Tensor) else None for inp in fake_inputs
-    ]
+    if not config.fx_wrapper:
+        # Replace non-tensor inputs with Nones: constant scalars are embedded
+        # in the graph, and symbolic scalars (SymInts) are not supported in
+        # non-fx_wrapper mode.
+        fake_inputs = [
+            inp if isinstance(inp, torch.Tensor) else None for inp in fake_inputs
+        ]
 
     if any(v is not None for v in fake_inputs):
         # Validate devices before switching to fake tensors.
         for idx, fi, i in zip(count(), fake_inputs, example_inputs_):
-            if fi is not None:
-                assert isinstance(i, torch.Tensor)
-                if fi.device != i.device:
-                    raise ValueError(
-                        f"Device mismatch between fake input and example input at position #{idx}: "
-                        f"{fi.device} vs {i.device}. If the model was exported via torch.export(), "
-                        "make sure torch.export() and torch.aot_compile() run on the same device."
-                    )
+            if isinstance(fi, torch.Tensor) and isinstance(i, torch.Tensor):
+                if fi.device != i.device:
+                    raise ValueError(
+                        f"Device mismatch between fake input and example input at position #{idx}: "
+                        f"{fi.device} vs {i.device}. If the model was exported via torch.export(), "
+                        "make sure torch.export() and torch.aot_compile() run on the same device."
+                    )
         return fake_inputs
     return example_inputs_
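
A toy illustration of the new gating, where plain ints stand in for SymInt values and fx_wrapper is a local flag rather than the real torch._inductor config:

import torch

def filter_fake_inputs(fake_inputs, fx_wrapper):
    # Mirrors the gated replacement above: outside fx_wrapper mode,
    # non-tensor placeholder values are nulled out; in fx_wrapper mode
    # they are kept so they can become real scalar arguments.
    if not fx_wrapper:
        return [x if isinstance(x, torch.Tensor) else None for x in fake_inputs]
    return fake_inputs

vals = [torch.empty(2), 6, 28]
print(filter_fake_inputs(vals, fx_wrapper=False))  # [tensor, None, None]
print(filter_fake_inputs(vals, fx_wrapper=True))   # [tensor, 6, 28]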