[dde] use sym_eq when checking normalized shape in layer_norm (#160683)

Use `sym_eq` to check equality on a tuple of ints/SymInts, so the check does not guard on data-dependent expressions (a short sketch of its behavior follows the traceback below).

### DDE
```
torch._dynamo.exc.UserError: Could not guard on data-dependent expression Eq(u0, u1) (unhinted: Eq(u0, u1)).  (Size-like symbols: u1, u0)

Caused by: return torch.nn.functional.layer_norm(  # test/inductor/test_unbacked_symints.py:527 in fn (_refs/__init__.py:3292 in native_layer_norm)
```
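
For context, a minimal sketch of what `sym_eq` does on concrete shape tuples (the tensor `w` below is just an illustration); on tuples containing unbacked SymInts it returns a `SymBool` instead of guarding:

```python
import torch
from torch.fx.experimental.symbolic_shapes import sym_eq

# On concrete ints, sym_eq recursively compares elements and joins the
# results with `and`, matching plain tuple `==`.
assert sym_eq((2, 3), (2, 3))
assert not sym_eq((2, 3), (2, 4))

# torch.Size is a tuple subclass, so tensor shapes compare the same way.
w = torch.randn(16)
assert sym_eq(w.shape, (16,))
```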

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160683
Approved by: https://github.com/bobrenjc93
Author: Colin Peppler
Date: 2025-08-14 15:20:29 -07:00
Committed by: PyTorch MergeBot
Parent: f7ad69f59c
Commit: a7c75ae976
2 changed files with 30 additions and 3 deletions

```diff
@@ -4511,6 +4511,29 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
         self.assertTrue(torch.allclose(ref[0], actual[0]))
         self.assertTrue(torch.allclose(ref[1], actual[1]))
 
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
+    def test_layer_norm_unbacked_normalized_shape(self):
+        class MyModel(torch.nn.Module):
+            def forward(self, scalar, weight, bias):
+                u1 = scalar.item()
+                y = torch.ones(2, u1)
+                return torch.nn.functional.layer_norm(
+                    input=y, normalized_shape=(u1,), weight=weight, bias=bias
+                )
+
+        model = MyModel()
+        inputs = (
+            torch.scalar_tensor(16, dtype=torch.int32),
+            torch.randn(16),
+            torch.randn(16),
+        )
+        ep = export(model, inputs)
+        actual = ep.module()(*inputs)
+        ref = model(*inputs)
+        self.assertTrue(torch.allclose(ref[0], actual[0]))
+
     def test_unbacked_3d_matmul(self):
         class Model(torch.nn.Module):
             def forward(self, x, repeat):
```

```diff
@@ -3277,6 +3277,8 @@ def native_layer_norm(
     bias: Optional[Tensor],
     eps: float,
 ) -> tuple[Tensor, Tensor, Tensor]:
+    from torch.fx.experimental.symbolic_shapes import sym_eq
+
     normalized_ndim = len(normalized_shape)
     torch._check(
         normalized_ndim >= 1,
@@ -3288,7 +3290,7 @@ def native_layer_norm(
     # while torch.Size([1, 2, 3]) == (1, 2, 3) is True
     # therefore we use tuple(normalized_shape)
     torch._check(
-        weight is None or weight.shape == tuple(normalized_shape),
+        weight is None or sym_eq(weight.shape, tuple(normalized_shape)),
         lambda: "Expected weight to be of same shape as normalized_shape, but got "
         + "weight of shape "
         + str(weight.shape)  # type: ignore[union-attr]
@@ -3296,7 +3298,7 @@ def native_layer_norm(
         + str(normalized_shape),
     )
     torch._check(
-        bias is None or bias.shape == tuple(normalized_shape),
+        bias is None or sym_eq(bias.shape, tuple(normalized_shape)),
         lambda: "Expected bias to be of same shape as normalized_shape, but got "
         + "bias of shape "
         + str(bias.shape)  # type: ignore[union-attr]
@@ -3305,7 +3307,9 @@ def native_layer_norm(
     )
     torch._check(
         input.ndim >= normalized_ndim
-        and input.shape[(input.ndim - normalized_ndim) :] == tuple(normalized_shape),
+        and sym_eq(
+            input.shape[(input.ndim - normalized_ndim) :], tuple(normalized_shape)
+        ),
         lambda: "Given normalized_shape="
         + str(normalized_shape)
         + ", expected input with shape "
```