[inductor] Add shape checks to ExpandView (#113839)

Currently `ExpandView` doesn't check that the expanded shape is valid which may
allow bugs to slip through which cause silent correctness issues.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113839
Approved by: https://github.com/ezyang
This commit is contained in:
Peter Bell
2024-01-03 17:55:29 +00:00
committed by PyTorch MergeBot
parent 1c69d0bdb5
commit f6be25bae6
3 changed files with 35 additions and 0 deletions

View File

@ -4,6 +4,7 @@ import torch
from torch._dynamo import config as dynamo_config
from torch._inductor import config as inductor_config
from torch.testing import make_tensor
from torch.testing._internal.common_utils import IS_LINUX, TestCase as TorchTestCase
from torch.testing._internal.inductor_utils import HAS_CUDA
@ -30,6 +31,16 @@ class TestUnbackedSymints(TorchTestCase):
torch.testing.assert_close(actual, expected)
def test_expand_mismatch(self):
def fn(x):
nz = x.nonzero()
return nz.expand([-1, 128])
x = make_tensor(32, 4, device="cpu", dtype=torch.float32, exclude_zero=True)
with dynamo_config.patch({"capture_dynamic_output_shape_ops": True}):
with self.assertRaises(torch._dynamo.exc.TorchRuntimeError):
actual = torch.compile(fn, fullgraph=True)(x)
def test_autotuning(self):
def fn(x, y):
nz = torch.nonzero(x)

View File

@ -1839,6 +1839,15 @@ class ExpandView(BaseView):
if new_size[i] == -1:
assert old_size[i] is not None
new_size[i] = old_size[i]
elif old_size[i] is None or old_size[i] == 1:
pass
else:
# Expect broadcast compatibility
new_size[i] = V.graph.sizevars.expect_equals(
new_size[i],
old_size[i],
msg=f"Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}",
)
return new_size
@classmethod

View File

@ -328,6 +328,21 @@ class SizeVarAllocator:
def guard_lt(self, left: Expr, right: Expr) -> None:
assert self.shape_env.evaluate_expr(sympy.Lt(left, right))
def expect_true(self, expr: Expr, *, msg: str) -> None:
expr = sympy_subs(expr, self.inv_precomputed_replacements)
self.shape_env.defer_runtime_assert(expr, msg, fx_node=V.graph.current_node)
def expect_equals(self, left: Expr, right: Expr, *, msg: str) -> Expr:
# Prefer returning the expression without unbacked symints
if self.shape_env.is_unbacked_symint(left):
self.expect_true(sympy.Eq(left, right), msg=msg)
return right
elif self.shape_env.is_unbacked_symint(right):
self.expect_true(sympy.Eq(left, right), msg=msg)
return left
else:
return self.guard_equals(left, right)
# The evaluate functions evaluate some symbolic sympy expression
# (NB: not necessarily an Expr) and return what the concrete result
# is, guarding on the expression being that result