mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[dde] use sym_or when checking normalized shape in layer_norm (#160683)
Use `sym_eq` to check equality on tuple of ints/symints ### DDE ``` torch._dynamo.exc.UserError: Could not guard on data-dependent expression Eq(u0, u1) (unhinted: Eq(u0, u1)). (Size-like symbols: u1, u0) Caused by: return torch.nn.functional.layer_norm( # test/inductor/test_unbacked_symints.py:527 in fn (_refs/__init__.py:3292 in native_layer_norm) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/160683 Approved by: https://github.com/bobrenjc93
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							f7ad69f59c
						
					
				
				
					commit
					a7c75ae976
				
			| @ -4511,6 +4511,29 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x): | ||||
|         self.assertTrue(torch.allclose(ref[0], actual[0])) | ||||
|         self.assertTrue(torch.allclose(ref[1], actual[1])) | ||||
|  | ||||
|     @torch._dynamo.config.patch(capture_scalar_outputs=True) | ||||
|     def test_layer_norm_unbacked_normalized_shape(self): | ||||
|         class MyModel(torch.nn.Module): | ||||
|             def forward(self, scalar, weight, bias): | ||||
|                 u1 = scalar.item() | ||||
|                 y = torch.ones(2, u1) | ||||
|  | ||||
|                 return torch.nn.functional.layer_norm( | ||||
|                     input=y, normalized_shape=(u1,), weight=weight, bias=bias | ||||
|                 ) | ||||
|  | ||||
|         model = MyModel() | ||||
|         inputs = ( | ||||
|             torch.scalar_tensor(16, dtype=torch.int32), | ||||
|             torch.randn(16), | ||||
|             torch.randn(16), | ||||
|         ) | ||||
|         ep = export(model, inputs) | ||||
|  | ||||
|         actual = ep.module()(*inputs) | ||||
|         ref = model(*inputs) | ||||
|         self.assertTrue(torch.allclose(ref[0], actual[0])) | ||||
|  | ||||
|     def test_unbacked_3d_matmul(self): | ||||
|         class Model(torch.nn.Module): | ||||
|             def forward(self, x, repeat): | ||||
|  | ||||
| @ -3277,6 +3277,8 @@ def native_layer_norm( | ||||
|     bias: Optional[Tensor], | ||||
|     eps: float, | ||||
| ) -> tuple[Tensor, Tensor, Tensor]: | ||||
|     from torch.fx.experimental.symbolic_shapes import sym_eq | ||||
|  | ||||
|     normalized_ndim = len(normalized_shape) | ||||
|     torch._check( | ||||
|         normalized_ndim >= 1, | ||||
| @ -3288,7 +3290,7 @@ def native_layer_norm( | ||||
|     # while torch.Size([1, 2, 3]) == (1, 2, 3) is True | ||||
|     # therefore we use tuple(normalized_shape) | ||||
|     torch._check( | ||||
|         weight is None or weight.shape == tuple(normalized_shape), | ||||
|         weight is None or sym_eq(weight.shape, tuple(normalized_shape)), | ||||
|         lambda: "Expected weight to be of same shape as normalized_shape, but got " | ||||
|         + "weight of shape " | ||||
|         + str(weight.shape)  # type: ignore[union-attr] | ||||
| @ -3296,7 +3298,7 @@ def native_layer_norm( | ||||
|         + str(normalized_shape), | ||||
|     ) | ||||
|     torch._check( | ||||
|         bias is None or bias.shape == tuple(normalized_shape), | ||||
|         bias is None or sym_eq(bias.shape, tuple(normalized_shape)), | ||||
|         lambda: "Expected bias to be of same shape as normalized_shape, but got " | ||||
|         + "bias of shape " | ||||
|         + str(bias.shape)  # type: ignore[union-attr] | ||||
| @ -3305,7 +3307,9 @@ def native_layer_norm( | ||||
|     ) | ||||
|     torch._check( | ||||
|         input.ndim >= normalized_ndim | ||||
|         and input.shape[(input.ndim - normalized_ndim) :] == tuple(normalized_shape), | ||||
|         and sym_eq( | ||||
|             input.shape[(input.ndim - normalized_ndim) :], tuple(normalized_shape) | ||||
|         ), | ||||
|         lambda: "Given normalized_shape=" | ||||
|         + str(normalized_shape) | ||||
|         + ", expected input with shape " | ||||
|  | ||||
		Reference in New Issue
	
	Block a user