[inductor] support linear & layer_norm unbacked (#155267)
### What

- Use `statically_known_true` over `guard_size_oblivious` in cases where we're checking an optimization path. Otherwise, it will DDE (throw a data-dependent error) and we can't take the safe/slower path.
- For broadcast checks, use `fallback_value=False` if we encounter a DDE. Typically, unbackeds would be ≥ 2, which falls in line with size-oblivious reasoning (i.e. when `size_oblivious=True`).

### Example DDE

```
torch._inductor.exc.InductorError: LoweringException: GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq((u0//387), 1) (unhinted: Eq((u0//387), 1)).  (Size-like symbols: u0)

Caused by: (_inductor/lowering.py:488 in broadcast_symbolic_shapes)
```

```
torch._inductor.exc.InductorError: LoweringException: GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq((u0//387), 1) (unhinted: Eq((u0//387), 1)).  (Size-like symbols: u0)

Caused by: (_inductor/ir.py:2797 in create)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155267
Approved by: https://github.com/eellison
Committed by: PyTorch MergeBot
Parent: be72bcf828
Commit: a6b7bea244
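To make the first point concrete, a minimal sketch (not part of this PR; it assumes the `ShapeEnv` and `statically_known_true` helpers exposed by `torch.fx.experimental.symbolic_shapes`): `statically_known_true` answers `False` on anything it cannot prove without guarding, whereas `guard_size_oblivious` can raise `GuardOnDataDependentSymNode` on the same expression.

```python
# Minimal sketch: why statically_known_true is safe on unbacked sizes.
from torch.fx.experimental.symbolic_shapes import ShapeEnv, statically_known_true

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()  # data-dependent size, no hint available

# statically_known_true never installs a guard: it returns True only when
# the expression is provable, and False otherwise -- including "can't tell",
# which is exactly when the lowering should take the safe/slower path.
print(statically_known_true(u0 == 1))   # False: undecidable, fall back
print(statically_known_true(u0 == u0))  # True: provable without guarding
```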
```diff
@@ -515,6 +515,37 @@ class TestUnbackedSymints(InductorTestCase):
         x = torch.tensor([1.0, 0.0, 1.0, 0.0], device=device)
         torch.compile(fn, fullgraph=True)(x)
 
+    @skipGPUIf(not HAS_GPU, "torch.compile for gpu requires triton")
+    @dynamo_config.patch({"capture_dynamic_output_shape_ops": True})
+    def test_unbacked_linear_layer_norm_input(self, device):
+        class MyModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(387, 128, bias=True, device=device)
+                self.layer_norm1 = torch.nn.LayerNorm(387, device=device)
+                self.layer_norm2 = torch.nn.LayerNorm(128, device=device)
+
+            def forward(self, x, mask):
+                masked_select = x.masked_select(mask)
+                view = masked_select.view(-1, 387)
+
+                linear = self.linear(view)
+                layer_norm1 = self.layer_norm1(view)
+                layer_norm2 = self.layer_norm2(linear)
+                return linear, layer_norm1, layer_norm2
+
+        model = MyModel()
+        inputs = (
+            torch.randn((256, 387), dtype=torch.float, device=device),
+            torch.randint(
+                low=0, high=2, size=(256, 1), dtype=torch.bool, device=device
+            ),
+        )
+
+        actual = torch.compile(model, fullgraph=True)(*inputs)
+        expected = model(*inputs)
+        torch.testing.assert_close(actual, expected)
+
 
 instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True)
```
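For context on the test above, a hedged eager-mode sketch (not part of the diff) of where the unbacked symbol comes from: `masked_select`'s output length depends on the values in `mask`, so under `torch.compile` it becomes an unbacked SymInt `u0`. Since `mask` has shape `(256, 1)` it broadcasts across whole rows, the selected count is always a multiple of 387, and `view(-1, 387)` produces the data-dependent leading dim `u0//387` seen in the errors.

```python
# Eager-mode sketch of the data-dependent shape exercised by the test.
import torch

x = torch.randn(256, 387)
mask = torch.randint(low=0, high=2, size=(256, 1), dtype=torch.bool)

flat = x.masked_select(mask)      # length depends on the mask's contents
print(flat.numel() % 387 == 0)    # True: the (256, 1) mask selects whole rows
rows = flat.view(-1, 387)         # leading dim is data-dependent (u0 // 387)
print(rows.shape)
```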
```diff
@@ -1667,9 +1667,9 @@ def native_layer_norm_backward(
 
     N = prod(inner_dims)  # type: ignore[arg-type]
     M = prod(outer_dims)  # type: ignore[arg-type]
-    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+    from torch.fx.experimental.symbolic_shapes import statically_known_true
 
-    if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0):
+    if statically_known_true(M == 0) or statically_known_true(N == 0):
         return (
             input.new_zeros(input_shape) if output_mask[0] else None,
             input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
```
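Two things change in this hunk: sizes are non-negative, so `M <= 0` is equivalent to `M == 0`; and `statically_known_true` simply answers `False` for an unbacked `M` or `N`, so the backward falls through to the general computation instead of raising. A small sketch of the rewritten check (the helper name is hypothetical, for illustration only):

```python
# Sketch of the rewritten early exit; statically_known_true accepts plain
# bools as well as SymBools, so this works for backed and unbacked sizes.
from torch.fx.experimental.symbolic_shapes import statically_known_true

def returns_zero_grads(M, N) -> bool:
    # Sizes are never negative, so "== 0" captures the old "<= 0" intent.
    return statically_known_true(M == 0) or statically_known_true(N == 0)

print(returns_zero_grads(0, 128))  # True: provably empty, emit zero grads
print(returns_zero_grads(4, 128))  # False: take the general path
```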
```diff
@@ -2876,7 +2876,7 @@ class ExpandView(BaseView):
                 assert old_size[i] is not None
                 new_size[i] = old_size[i]
             elif old_size[i] is None or V.graph.sizevars.shape_env.evaluate_expr(
-                sympy.Eq(old_size[i], 1), size_oblivious=True
+                sympy.Eq(old_size[i], 1), fallback_value=False
             ):
                 pass
             else:
```
```diff
@@ -2903,7 +2903,7 @@ class ExpandView(BaseView):
             new_stride.append(
                 stride
                 if not V.graph.sizevars.shape_env.evaluate_expr(
-                    sympy.Eq(size, 1), size_oblivious=True
+                    sympy.Eq(size, 1), fallback_value=False
                 )
                 else sympy.S.Zero
             )
```
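The `fallback_value=False` mechanism used in these two hunks (and in the ones below), as a hedged sketch built on the `ShapeEnv.evaluate_expr` call visible in the diff: when the expression is data-dependent, `evaluate_expr` returns the fallback instead of raising, which here means "assume the dim is not 1". That matches size-oblivious reasoning, where unbacked sizes are treated as ≥ 2.

```python
# Hedged sketch of fallback_value semantics, mirroring the calls in this diff.
import sympy
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()

# Instead of raising GuardOnDataDependentSymNode, an undecidable Eq(u0, 1)
# evaluates to the fallback: treat the unbacked dim as non-broadcasting.
print(shape_env.evaluate_expr(sympy.Eq(u0.node.expr, 1), fallback_value=False))
```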
```diff
@@ -489,11 +489,11 @@ def broadcast_symbolic_shapes(a, b):
     output = []
     for x, y in itertools.zip_longest(reversed(a), reversed(b), fillvalue=sympy.S.One):
         if V.graph.sizevars.shape_env.evaluate_expr(
-            sympy.Eq(y, 1), size_oblivious=True
+            sympy.Eq(y, 1), fallback_value=False
         ):
             output.append(x)
         elif V.graph.sizevars.shape_env.evaluate_expr(
-            sympy.Eq(x, 1), size_oblivious=True
+            sympy.Eq(x, 1), fallback_value=False
         ):
             output.append(y)
         else:
```
```diff
@@ -939,26 +939,14 @@ def broadcast_tensors(*inputs):
     outputs = []
     for x in inputs:
         sizes = x.get_size()
-        if len(sizes) != len(target) or any(
-            (
-                (
-                    V.graph.sizevars.shape_env.evaluate_expr(
-                        sympy.Eq(a, 1), size_oblivious=True
-                    )
-                    and not V.graph.sizevars.shape_env.evaluate_expr(
-                        sympy.Eq(b, 1), size_oblivious=True
-                    )
-                )
-                or (
-                    not V.graph.sizevars.shape_env.evaluate_expr(
-                        sympy.Eq(a, 1), size_oblivious=True
-                    )
-                    and V.graph.sizevars.shape_env.evaluate_expr(
-                        sympy.Eq(b, 1), size_oblivious=True
-                    )
-                )
-            )
-            for a, b in zip(sizes, target)
+
+        def is_length_one(size: sympy.Expr):
+            return V.graph.sizevars.shape_env.evaluate_expr(
+                sympy.Eq(size, 1), fallback_value=False
+            )
+
+        if len(sizes) != len(target) or any(
+            is_length_one(a) != is_length_one(b) for a, b in zip(sizes, target)
         ):
             x = expand(x, target)
         outputs.append(x)
```
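The refactor in this last hunk is a boolean simplification plus the same predicate swap as in the earlier hunks: for booleans `p` and `q`, `(p and not q) or (not p and q)` is exactly `p != q`, so `is_length_one(a) != is_length_one(b)` reproduces the removed either-side-broadcasts condition. A quick exhaustive check:

```python
# Exhaustive check that the XOR rewrite matches the removed condition.
for p in (False, True):
    for q in (False, True):
        assert ((p and not q) or (not p and q)) == (p != q)
print("equivalent for all boolean inputs")
```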