[inductor] support linear & layer_norm unbacked (#155267)
### What

- Use `statically_known_true` over `guard_size_oblivious` in cases where we're checking an optimization path. Otherwise it raises a data-dependent error (DDE) instead of letting us fall back to the safe/slower path.
- For broadcast checks, use `fallback_value=False` if we encounter a DDE. Typically, unbacked sizes are ≥ 2, which falls in line with size-oblivious reasoning (i.e. the behavior when `size_oblivious=True`).

### Example DDE

```
torch._inductor.exc.InductorError: LoweringException: GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq((u0//387), 1) (unhinted: Eq((u0//387), 1)). (Size-like symbols: u0)

Caused by: (_inductor/lowering.py:488 in broadcast_symbolic_shapes)
```

```
torch._inductor.exc.InductorError: LoweringException: GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq((u0//387), 1) (unhinted: Eq((u0//387), 1)). (Size-like symbols: u0)

Caused by: (_inductor/ir.py:2797 in create)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155267
Approved by: https://github.com/eellison
Committed by: PyTorch MergeBot
Parent: be72bcf828
Commit: a6b7bea244
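A standalone sketch of the first point above (not part of the commit; `early_return_sketch` is a made-up helper): `statically_known_true` answers `False` whenever it cannot prove the expression, so a caller checking an optimization path simply falls through to the general path instead of raising `GuardOnDataDependentSymNode` the way a `guard_size_oblivious` check can on an unbacked `SymInt`.

```python
# Standalone sketch, not from the commit: unlike guard_size_oblivious, which
# must commit to an answer and can raise GuardOnDataDependentSymNode on an
# unbacked SymInt, statically_known_true just answers False when it cannot
# prove the expression, so the caller takes the general path.
from torch.fx.experimental.symbolic_shapes import statically_known_true


def early_return_sketch(M, N):
    # Hypothetical helper; M and N could be plain ints here or unbacked
    # SymInts when this runs under torch.compile.
    if statically_known_true(M == 0) or statically_known_true(N == 0):
        return "optimization path: provably empty"
    return "general path: size is unknown or nonzero"


print(early_return_sketch(0, 128))    # optimization path
print(early_return_sketch(256, 128))  # general path
```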
```diff
@@ -515,6 +515,37 @@ class TestUnbackedSymints(InductorTestCase):
         x = torch.tensor([1.0, 0.0, 1.0, 0.0], device=device)
         torch.compile(fn, fullgraph=True)(x)
 
+    @skipGPUIf(not HAS_GPU, "torch.compile for gpu requires triton")
+    @dynamo_config.patch({"capture_dynamic_output_shape_ops": True})
+    def test_unbacked_linear_layer_norm_input(self, device):
+        class MyModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(387, 128, bias=True, device=device)
+                self.layer_norm1 = torch.nn.LayerNorm(387, device=device)
+                self.layer_norm2 = torch.nn.LayerNorm(128, device=device)
+
+            def forward(self, x, mask):
+                masked_select = x.masked_select(mask)
+                view = masked_select.view(-1, 387)
+
+                linear = self.linear(view)
+                layer_norm1 = self.layer_norm1(view)
+                layer_norm2 = self.layer_norm2(linear)
+                return linear, layer_norm1, layer_norm2
+
+        model = MyModel()
+        inputs = (
+            torch.randn((256, 387), dtype=torch.float, device=device),
+            torch.randint(
+                low=0, high=2, size=(256, 1), dtype=torch.bool, device=device
+            ),
+        )
+
+        actual = torch.compile(model, fullgraph=True)(*inputs)
+        expected = model(*inputs)
+        torch.testing.assert_close(actual, expected)
+
 
 instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True)
 
```
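For context on what the new test exercises: `masked_select` has a data-dependent output length, which becomes an unbacked `u0` under `torch.compile`, and the following `view(-1, 387)` makes the leading dimension `u0//387`, the exact expression in the DDE above. A quick eager-mode sketch of that shape flow (mirrors the test inputs; not part of the commit):

```python
import torch

x = torch.randn(256, 387)
mask = torch.randint(0, 2, (256, 1), dtype=torch.bool)

sel = x.masked_select(mask)  # output length depends on mask contents (u0 under compile)
view = sel.view(-1, 387)     # leading dim is u0 // 387 symbolically
print(view.shape)            # concrete in eager mode; (u0//387, 387) under torch.compile
```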
```diff
@@ -1667,9 +1667,9 @@ def native_layer_norm_backward(
 
     N = prod(inner_dims)  # type: ignore[arg-type]
     M = prod(outer_dims)  # type: ignore[arg-type]
-    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+    from torch.fx.experimental.symbolic_shapes import statically_known_true
 
-    if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0):
+    if statically_known_true(M == 0) or statically_known_true(N == 0):
         return (
             input.new_zeros(input_shape) if output_mask[0] else None,
             input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
```
```diff
@@ -2876,7 +2876,7 @@ class ExpandView(BaseView):
                 assert old_size[i] is not None
                 new_size[i] = old_size[i]
             elif old_size[i] is None or V.graph.sizevars.shape_env.evaluate_expr(
-                sympy.Eq(old_size[i], 1), size_oblivious=True
+                sympy.Eq(old_size[i], 1), fallback_value=False
             ):
                 pass
             else:
@@ -2903,7 +2903,7 @@ class ExpandView(BaseView):
             new_stride.append(
                 stride
                 if not V.graph.sizevars.shape_env.evaluate_expr(
-                    sympy.Eq(size, 1), size_oblivious=True
+                    sympy.Eq(size, 1), fallback_value=False
                 )
                 else sympy.S.Zero
             )
```
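A toy, plain-Python model of what the `ExpandView` stride check decides (hypothetical `expand_strides_sketch`, not Inductor code): dims that are provably 1 broadcast with stride 0, everything else keeps its stride. With `fallback_value=False`, an unbacked dim answers "not provably 1" and keeps its real stride, matching the size-oblivious assumption that unbacked sizes are ≥ 2.

```python
# Toy model of ExpandView's stride selection; `provably_one` stands in for
# evaluate_expr(Eq(size, 1), fallback_value=False).
def expand_strides_sketch(old_size, old_stride, provably_one=lambda d: d == 1):
    return [0 if provably_one(s) else st for s, st in zip(old_size, old_stride)]


print(expand_strides_sketch((1, 387), (387, 1)))  # [0, 1]   -> dim 0 broadcasts
print(expand_strides_sketch((4, 387), (387, 1)))  # [387, 1] -> nothing broadcasts
```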
```diff
@@ -489,11 +489,11 @@ def broadcast_symbolic_shapes(a, b):
     output = []
     for x, y in itertools.zip_longest(reversed(a), reversed(b), fillvalue=sympy.S.One):
         if V.graph.sizevars.shape_env.evaluate_expr(
-            sympy.Eq(y, 1), size_oblivious=True
+            sympy.Eq(y, 1), fallback_value=False
         ):
             output.append(x)
         elif V.graph.sizevars.shape_env.evaluate_expr(
-            sympy.Eq(x, 1), size_oblivious=True
+            sympy.Eq(x, 1), fallback_value=False
         ):
             output.append(y)
         else:
```
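The broadcast rule itself, sketched in plain Python (hypothetical `broadcast_shapes_sketch`, not the Inductor function): walk the shapes right to left and take the other side whenever one side is provably 1. With `fallback_value=False`, an unbacked dim is never treated as 1, so it is kept as-is.

```python
import itertools


def broadcast_shapes_sketch(a, b, provably_one=lambda d: d == 1):
    out = []
    for x, y in itertools.zip_longest(reversed(a), reversed(b), fillvalue=1):
        if provably_one(y):
            out.append(x)
        elif provably_one(x):
            out.append(y)
        else:
            out.append(x)  # neither side is provably 1; sizes are assumed to match
    return tuple(reversed(out))


print(broadcast_shapes_sketch((256, 387), (387,)))    # (256, 387)
print(broadcast_shapes_sketch((256, 1), (256, 387)))  # (256, 387)
```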
```diff
@@ -939,26 +939,14 @@ def broadcast_tensors(*inputs):
     outputs = []
     for x in inputs:
         sizes = x.get_size()
-        if len(sizes) != len(target) or any(
-            (
-                (
-                    V.graph.sizevars.shape_env.evaluate_expr(
-                        sympy.Eq(a, 1), size_oblivious=True
-                    )
-                    and not V.graph.sizevars.shape_env.evaluate_expr(
-                        sympy.Eq(b, 1), size_oblivious=True
-                    )
-                )
-                or (
-                    not V.graph.sizevars.shape_env.evaluate_expr(
-                        sympy.Eq(a, 1), size_oblivious=True
-                    )
-                    and V.graph.sizevars.shape_env.evaluate_expr(
-                        sympy.Eq(b, 1), size_oblivious=True
-                    )
-                )
-            )
-            for a, b in zip(sizes, target)
+
+        def is_length_one(size: sympy.Expr):
+            return V.graph.sizevars.shape_env.evaluate_expr(
+                sympy.Eq(size, 1), fallback_value=False
+            )
+
+        if len(sizes) != len(target) or any(
+            is_length_one(a) != is_length_one(b) for a, b in zip(sizes, target)
         ):
             x = expand(x, target)
         outputs.append(x)
```
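The refactor above collapses the old nested and/or expression into `is_length_one(a) != is_length_one(b)`; on booleans `!=` is exclusive-or, so the predicate is still "exactly one of the two dims has length 1". A tiny standalone check of that equivalence:

```python
# Exhaustive check that the old nested and/or predicate equals boolean XOR.
old = lambda a, b: (a and not b) or (not a and b)
new = lambda a, b: a != b
assert all(old(a, b) == new(a, b) for a in (True, False) for b in (True, False))
```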