mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dde] use sym_or when checking normalized shape in layer_norm (#160683)
Use `sym_eq` to check equality on tuple of ints/symints ### DDE ``` torch._dynamo.exc.UserError: Could not guard on data-dependent expression Eq(u0, u1) (unhinted: Eq(u0, u1)). (Size-like symbols: u1, u0) Caused by: return torch.nn.functional.layer_norm( # test/inductor/test_unbacked_symints.py:527 in fn (_refs/__init__.py:3292 in native_layer_norm) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/160683 Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
f7ad69f59c
commit
a7c75ae976
@ -4511,6 +4511,29 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
self.assertTrue(torch.allclose(ref[0], actual[0]))
|
||||
self.assertTrue(torch.allclose(ref[1], actual[1]))
|
||||
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_layer_norm_unbacked_normalized_shape(self):
|
||||
class MyModel(torch.nn.Module):
|
||||
def forward(self, scalar, weight, bias):
|
||||
u1 = scalar.item()
|
||||
y = torch.ones(2, u1)
|
||||
|
||||
return torch.nn.functional.layer_norm(
|
||||
input=y, normalized_shape=(u1,), weight=weight, bias=bias
|
||||
)
|
||||
|
||||
model = MyModel()
|
||||
inputs = (
|
||||
torch.scalar_tensor(16, dtype=torch.int32),
|
||||
torch.randn(16),
|
||||
torch.randn(16),
|
||||
)
|
||||
ep = export(model, inputs)
|
||||
|
||||
actual = ep.module()(*inputs)
|
||||
ref = model(*inputs)
|
||||
self.assertTrue(torch.allclose(ref[0], actual[0]))
|
||||
|
||||
def test_unbacked_3d_matmul(self):
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x, repeat):
|
||||
|
||||
@ -3277,6 +3277,8 @@ def native_layer_norm(
|
||||
bias: Optional[Tensor],
|
||||
eps: float,
|
||||
) -> tuple[Tensor, Tensor, Tensor]:
|
||||
from torch.fx.experimental.symbolic_shapes import sym_eq
|
||||
|
||||
normalized_ndim = len(normalized_shape)
|
||||
torch._check(
|
||||
normalized_ndim >= 1,
|
||||
@ -3288,7 +3290,7 @@ def native_layer_norm(
|
||||
# while torch.Size([1, 2, 3]) == (1, 2, 3) is True
|
||||
# therefore we use tuple(normalized_shape)
|
||||
torch._check(
|
||||
weight is None or weight.shape == tuple(normalized_shape),
|
||||
weight is None or sym_eq(weight.shape, tuple(normalized_shape)),
|
||||
lambda: "Expected weight to be of same shape as normalized_shape, but got "
|
||||
+ "weight of shape "
|
||||
+ str(weight.shape) # type: ignore[union-attr]
|
||||
@ -3296,7 +3298,7 @@ def native_layer_norm(
|
||||
+ str(normalized_shape),
|
||||
)
|
||||
torch._check(
|
||||
bias is None or bias.shape == tuple(normalized_shape),
|
||||
bias is None or sym_eq(bias.shape, tuple(normalized_shape)),
|
||||
lambda: "Expected bias to be of same shape as normalized_shape, but got "
|
||||
+ "bias of shape "
|
||||
+ str(bias.shape) # type: ignore[union-attr]
|
||||
@ -3305,7 +3307,9 @@ def native_layer_norm(
|
||||
)
|
||||
torch._check(
|
||||
input.ndim >= normalized_ndim
|
||||
and input.shape[(input.ndim - normalized_ndim) :] == tuple(normalized_shape),
|
||||
and sym_eq(
|
||||
input.shape[(input.ndim - normalized_ndim) :], tuple(normalized_shape)
|
||||
),
|
||||
lambda: "Given normalized_shape="
|
||||
+ str(normalized_shape)
|
||||
+ ", expected input with shape "
|
||||
|
||||
Reference in New Issue
Block a user