[inductor] support linear & layer_norm unbacked (#155267)

### What
- Use `statically_known_true` instead of `guard_size_oblivious` in cases where we're only checking whether an optimization path applies. Otherwise we raise a data-dependent error (DDE) instead of simply falling back to the safe/slower path (a short sketch follows this list).
- For broadcast checks, pass `fallback_value=False` so that a would-be DDE just evaluates to `False`. Typically, unbacked sizes are assumed to be ≥2, which falls in line with size-oblivious reasoning (i.e. `size_oblivious=True`); see the sketch after the example DDEs below.
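
As a quick illustration of the first point, here is a minimal standalone sketch (not part of this PR; `maybe_zero_fill_fast_path` is a made-up helper). When the question can't be proven, `statically_known_true` returns `False` and we fall through to the general path rather than raising a DDE:

```python
# Hypothetical helper (not from the PR): gate an optimization path with
# statically_known_true. For unbacked sizes the condition can't be proven,
# so it returns False and the general/slower path is taken instead of
# raising a GuardOnDataDependentSymNode error.
from torch.fx.experimental.symbolic_shapes import statically_known_true

def maybe_zero_fill_fast_path(M, N):
    # statically_known_true accepts plain bools as well as SymBools
    if statically_known_true(M == 0) or statically_known_true(N == 0):
        return "fast path: result is empty, just return zeros"
    return "general path: do the full computation"

print(maybe_zero_fill_fast_path(0, 8))  # fast path
print(maybe_zero_fill_fast_path(4, 8))  # general path
```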

### Example DDE
```
torch._inductor.exc.InductorError: LoweringException: GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq((u0//387), 1) (unhinted: Eq((u0//387), 1)).  (Size-like symbols: u0)

Caused by: (_inductor/lowering.py:488 in broadcast_symbolic_shapes)
```
```
torch._inductor.exc.InductorError: LoweringException: GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq((u0//387), 1) (unhinted: Eq((u0//387), 1)).  (Size-like symbols: u0)

Caused by: (_inductor/ir.py:2797 in create)
```
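
And a standalone sketch of the second point (again not part of the diff; it assumes the `fallback_value` keyword of `ShapeEnv.evaluate_expr` that the change below relies on): with an unbacked size, `Eq(u0, 1)` can't be decided, so `fallback_value=False` makes the broadcast check treat the dim as not-length-1 instead of raising the error above.

```python
# Standalone sketch, assuming ShapeEnv.evaluate_expr supports the
# fallback_value keyword used by this change. Eq(u0, 1) can't be decided
# for an unbacked size, so instead of a DDE we get False back, i.e. the
# broadcast logic treats the unbacked dim as "not length 1".
import sympy
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()  # unbacked SymInt, no hint available

is_one = shape_env.evaluate_expr(
    sympy.Eq(u0.node.expr, 1), fallback_value=False
)
print(is_one)  # False, and no GuardOnDataDependentSymNode is raised
```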

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155267
Approved by: https://github.com/eellison
Author: Colin Peppler
Date: 2025-07-22 13:34:32 -07:00
Committed by: PyTorch MergeBot
Parent: be72bcf828
Commit: a6b7bea244

4 changed files with 44 additions and 25 deletions


@@ -515,6 +515,37 @@ class TestUnbackedSymints(InductorTestCase):
         x = torch.tensor([1.0, 0.0, 1.0, 0.0], device=device)
         torch.compile(fn, fullgraph=True)(x)
 
+    @skipGPUIf(not HAS_GPU, "torch.compile for gpu requires triton")
+    @dynamo_config.patch({"capture_dynamic_output_shape_ops": True})
+    def test_unbacked_linear_layer_norm_input(self, device):
+        class MyModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(387, 128, bias=True, device=device)
+                self.layer_norm1 = torch.nn.LayerNorm(387, device=device)
+                self.layer_norm2 = torch.nn.LayerNorm(128, device=device)
+
+            def forward(self, x, mask):
+                masked_select = x.masked_select(mask)
+                view = masked_select.view(-1, 387)
+                linear = self.linear(view)
+                layer_norm1 = self.layer_norm1(view)
+                layer_norm2 = self.layer_norm2(linear)
+                return linear, layer_norm1, layer_norm2
+
+        model = MyModel()
+        inputs = (
+            torch.randn((256, 387), dtype=torch.float, device=device),
+            torch.randint(
+                low=0, high=2, size=(256, 1), dtype=torch.bool, device=device
+            ),
+        )
+
+        actual = torch.compile(model, fullgraph=True)(*inputs)
+        expected = model(*inputs)
+        torch.testing.assert_close(actual, expected)
+
 
 instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True)


@@ -1667,9 +1667,9 @@ def native_layer_norm_backward(
     N = prod(inner_dims)  # type: ignore[arg-type]
     M = prod(outer_dims)  # type: ignore[arg-type]
-    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+    from torch.fx.experimental.symbolic_shapes import statically_known_true
 
-    if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0):
+    if statically_known_true(M == 0) or statically_known_true(N == 0):
         return (
             input.new_zeros(input_shape) if output_mask[0] else None,
             input.new_zeros(input_shape[axis:]) if output_mask[1] else None,


@@ -2876,7 +2876,7 @@ class ExpandView(BaseView):
                 assert old_size[i] is not None
                 new_size[i] = old_size[i]
             elif old_size[i] is None or V.graph.sizevars.shape_env.evaluate_expr(
-                sympy.Eq(old_size[i], 1), size_oblivious=True
+                sympy.Eq(old_size[i], 1), fallback_value=False
             ):
                 pass
             else:
@@ -2903,7 +2903,7 @@ class ExpandView(BaseView):
                 new_stride.append(
                     stride
                     if not V.graph.sizevars.shape_env.evaluate_expr(
-                        sympy.Eq(size, 1), size_oblivious=True
+                        sympy.Eq(size, 1), fallback_value=False
                     )
                     else sympy.S.Zero
                 )


@@ -489,11 +489,11 @@ def broadcast_symbolic_shapes(a, b):
     output = []
     for x, y in itertools.zip_longest(reversed(a), reversed(b), fillvalue=sympy.S.One):
         if V.graph.sizevars.shape_env.evaluate_expr(
-            sympy.Eq(y, 1), size_oblivious=True
+            sympy.Eq(y, 1), fallback_value=False
         ):
             output.append(x)
         elif V.graph.sizevars.shape_env.evaluate_expr(
-            sympy.Eq(x, 1), size_oblivious=True
+            sympy.Eq(x, 1), fallback_value=False
         ):
             output.append(y)
         else:
@@ -939,26 +939,14 @@ def broadcast_tensors(*inputs):
     outputs = []
     for x in inputs:
         sizes = x.get_size()
-        if len(sizes) != len(target) or any(
-            (
-                (
-                    V.graph.sizevars.shape_env.evaluate_expr(
-                        sympy.Eq(a, 1), size_oblivious=True
-                    )
-                    and not V.graph.sizevars.shape_env.evaluate_expr(
-                        sympy.Eq(b, 1), size_oblivious=True
-                    )
-                )
-                or (
-                    not V.graph.sizevars.shape_env.evaluate_expr(
-                        sympy.Eq(a, 1), size_oblivious=True
-                    )
-                    and V.graph.sizevars.shape_env.evaluate_expr(
-                        sympy.Eq(b, 1), size_oblivious=True
-                    )
-                )
-            )
-            for a, b in zip(sizes, target)
+
+        def is_length_one(size: sympy.Expr):
+            return V.graph.sizevars.shape_env.evaluate_expr(
+                sympy.Eq(size, 1), fallback_value=False
+            )
+
+        if len(sizes) != len(target) or any(
+            is_length_one(a) != is_length_one(b) for a, b in zip(sizes, target)
         ):
             x = expand(x, target)
         outputs.append(x)