mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[inductor] support linear & layer_norm unbacked (#155267)
### What - Use `statically_known_true` over `guard_size_oblivious` in cases where we're checking an optimization path. Otherwise, it will DDE and we can't take the safe/slower path. - For broadcast checks, use `fallback=False` if we encounter a DDE. Typically, unbackeds would be ≥2 and that falls inline with size-oblivious reasoning (i.e. when `size_oblivious=True`). ### Example DDE ``` torch._inductor.exc.InductorError: LoweringException: GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq((u0//387), 1) (unhinted: Eq((u0//387), 1)). (Size-like symbols: u0) Caused by: (_inductor/lowering.py:488 in broadcast_symbolic_shapes) ``` ``` torch._inductor.exc.InductorError: LoweringException: GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq((u0//387), 1) (unhinted: Eq((u0//387), 1)). (Size-like symbols: u0) Caused by: (_inductor/ir.py:2797 in create) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/155267 Approved by: https://github.com/eellison
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							be72bcf828
						
					
				
				
					commit
					a6b7bea244
				
			| @ -515,6 +515,37 @@ class TestUnbackedSymints(InductorTestCase): | |||||||
|         x = torch.tensor([1.0, 0.0, 1.0, 0.0], device=device) |         x = torch.tensor([1.0, 0.0, 1.0, 0.0], device=device) | ||||||
|         torch.compile(fn, fullgraph=True)(x) |         torch.compile(fn, fullgraph=True)(x) | ||||||
|  |  | ||||||
|  |     @skipGPUIf(not HAS_GPU, "torch.compile for gpu requires triton") | ||||||
|  |     @dynamo_config.patch({"capture_dynamic_output_shape_ops": True}) | ||||||
|  |     def test_unbacked_linear_layer_norm_input(self, device): | ||||||
|  |         class MyModel(torch.nn.Module): | ||||||
|  |             def __init__(self): | ||||||
|  |                 super().__init__() | ||||||
|  |                 self.linear = torch.nn.Linear(387, 128, bias=True, device=device) | ||||||
|  |                 self.layer_norm1 = torch.nn.LayerNorm(387, device=device) | ||||||
|  |                 self.layer_norm2 = torch.nn.LayerNorm(128, device=device) | ||||||
|  |  | ||||||
|  |             def forward(self, x, mask): | ||||||
|  |                 masked_select = x.masked_select(mask) | ||||||
|  |                 view = masked_select.view(-1, 387) | ||||||
|  |  | ||||||
|  |                 linear = self.linear(view) | ||||||
|  |                 layer_norm1 = self.layer_norm1(view) | ||||||
|  |                 layer_norm2 = self.layer_norm2(linear) | ||||||
|  |                 return linear, layer_norm1, layer_norm2 | ||||||
|  |  | ||||||
|  |         model = MyModel() | ||||||
|  |         inputs = ( | ||||||
|  |             torch.randn((256, 387), dtype=torch.float, device=device), | ||||||
|  |             torch.randint( | ||||||
|  |                 low=0, high=2, size=(256, 1), dtype=torch.bool, device=device | ||||||
|  |             ), | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         actual = torch.compile(model, fullgraph=True)(*inputs) | ||||||
|  |         expected = model(*inputs) | ||||||
|  |         torch.testing.assert_close(actual, expected) | ||||||
|  |  | ||||||
|  |  | ||||||
| instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True) | instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True) | ||||||
|  |  | ||||||
|  | |||||||
| @ -1667,9 +1667,9 @@ def native_layer_norm_backward( | |||||||
|  |  | ||||||
|     N = prod(inner_dims)  # type: ignore[arg-type] |     N = prod(inner_dims)  # type: ignore[arg-type] | ||||||
|     M = prod(outer_dims)  # type: ignore[arg-type] |     M = prod(outer_dims)  # type: ignore[arg-type] | ||||||
|     from torch.fx.experimental.symbolic_shapes import guard_size_oblivious |     from torch.fx.experimental.symbolic_shapes import statically_known_true | ||||||
|  |  | ||||||
|     if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0): |     if statically_known_true(M == 0) or statically_known_true(N == 0): | ||||||
|         return ( |         return ( | ||||||
|             input.new_zeros(input_shape) if output_mask[0] else None, |             input.new_zeros(input_shape) if output_mask[0] else None, | ||||||
|             input.new_zeros(input_shape[axis:]) if output_mask[1] else None, |             input.new_zeros(input_shape[axis:]) if output_mask[1] else None, | ||||||
|  | |||||||
| @ -2876,7 +2876,7 @@ class ExpandView(BaseView): | |||||||
|                 assert old_size[i] is not None |                 assert old_size[i] is not None | ||||||
|                 new_size[i] = old_size[i] |                 new_size[i] = old_size[i] | ||||||
|             elif old_size[i] is None or V.graph.sizevars.shape_env.evaluate_expr( |             elif old_size[i] is None or V.graph.sizevars.shape_env.evaluate_expr( | ||||||
|                 sympy.Eq(old_size[i], 1), size_oblivious=True |                 sympy.Eq(old_size[i], 1), fallback_value=False | ||||||
|             ): |             ): | ||||||
|                 pass |                 pass | ||||||
|             else: |             else: | ||||||
| @ -2903,7 +2903,7 @@ class ExpandView(BaseView): | |||||||
|                 new_stride.append( |                 new_stride.append( | ||||||
|                     stride |                     stride | ||||||
|                     if not V.graph.sizevars.shape_env.evaluate_expr( |                     if not V.graph.sizevars.shape_env.evaluate_expr( | ||||||
|                         sympy.Eq(size, 1), size_oblivious=True |                         sympy.Eq(size, 1), fallback_value=False | ||||||
|                     ) |                     ) | ||||||
|                     else sympy.S.Zero |                     else sympy.S.Zero | ||||||
|                 ) |                 ) | ||||||
|  | |||||||
| @ -489,11 +489,11 @@ def broadcast_symbolic_shapes(a, b): | |||||||
|     output = [] |     output = [] | ||||||
|     for x, y in itertools.zip_longest(reversed(a), reversed(b), fillvalue=sympy.S.One): |     for x, y in itertools.zip_longest(reversed(a), reversed(b), fillvalue=sympy.S.One): | ||||||
|         if V.graph.sizevars.shape_env.evaluate_expr( |         if V.graph.sizevars.shape_env.evaluate_expr( | ||||||
|             sympy.Eq(y, 1), size_oblivious=True |             sympy.Eq(y, 1), fallback_value=False | ||||||
|         ): |         ): | ||||||
|             output.append(x) |             output.append(x) | ||||||
|         elif V.graph.sizevars.shape_env.evaluate_expr( |         elif V.graph.sizevars.shape_env.evaluate_expr( | ||||||
|             sympy.Eq(x, 1), size_oblivious=True |             sympy.Eq(x, 1), fallback_value=False | ||||||
|         ): |         ): | ||||||
|             output.append(y) |             output.append(y) | ||||||
|         else: |         else: | ||||||
| @ -939,26 +939,14 @@ def broadcast_tensors(*inputs): | |||||||
|     outputs = [] |     outputs = [] | ||||||
|     for x in inputs: |     for x in inputs: | ||||||
|         sizes = x.get_size() |         sizes = x.get_size() | ||||||
|         if len(sizes) != len(target) or any( |  | ||||||
|             ( |         def is_length_one(size: sympy.Expr): | ||||||
|                 ( |             return V.graph.sizevars.shape_env.evaluate_expr( | ||||||
|                     V.graph.sizevars.shape_env.evaluate_expr( |                 sympy.Eq(size, 1), fallback_value=False | ||||||
|                         sympy.Eq(a, 1), size_oblivious=True |  | ||||||
|                     ) |  | ||||||
|                     and not V.graph.sizevars.shape_env.evaluate_expr( |  | ||||||
|                         sympy.Eq(b, 1), size_oblivious=True |  | ||||||
|                     ) |  | ||||||
|                 ) |  | ||||||
|                 or ( |  | ||||||
|                     not V.graph.sizevars.shape_env.evaluate_expr( |  | ||||||
|                         sympy.Eq(a, 1), size_oblivious=True |  | ||||||
|                     ) |  | ||||||
|                     and V.graph.sizevars.shape_env.evaluate_expr( |  | ||||||
|                         sympy.Eq(b, 1), size_oblivious=True |  | ||||||
|                     ) |  | ||||||
|                 ) |  | ||||||
|             ) |             ) | ||||||
|             for a, b in zip(sizes, target) |  | ||||||
|  |         if len(sizes) != len(target) or any( | ||||||
|  |             is_length_one(a) != is_length_one(b) for a, b in zip(sizes, target) | ||||||
|         ): |         ): | ||||||
|             x = expand(x, target) |             x = expand(x, target) | ||||||
|         outputs.append(x) |         outputs.append(x) | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user