mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] Add shape checks to ExpandView (#113839)
Currently `ExpandView` doesn't check that the expanded shape is valid, which may allow bugs that cause silent correctness issues to slip through. Pull Request resolved: https://github.com/pytorch/pytorch/pull/113839 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
1c69d0bdb5
commit
f6be25bae6
@ -4,6 +4,7 @@ import torch
|
||||
|
||||
from torch._dynamo import config as dynamo_config
|
||||
from torch._inductor import config as inductor_config
|
||||
from torch.testing import make_tensor
|
||||
|
||||
from torch.testing._internal.common_utils import IS_LINUX, TestCase as TorchTestCase
|
||||
from torch.testing._internal.inductor_utils import HAS_CUDA
|
||||
@ -30,6 +31,16 @@ class TestUnbackedSymints(TorchTestCase):
|
||||
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
def test_expand_mismatch(self):
    """Expanding an unbacked-shaped tensor to a non-broadcastable size
    must raise at compile time instead of producing silently wrong results.

    ``x.nonzero()`` on a 2-D input yields shape ``(num_nonzero, 2)``;
    expanding dim 1 (size 2) to 128 is not broadcast-compatible, so the
    compiled graph should fail with ``TorchRuntimeError``.
    """

    def fn(x):
        nz = x.nonzero()
        return nz.expand([-1, 128])

    # exclude_zero=True keeps every element nonzero, so nonzero() returns
    # all 32*4 positions — the data-dependent size is still unbacked to
    # the compiler, which is what this test exercises.
    x = make_tensor(32, 4, device="cpu", dtype=torch.float32, exclude_zero=True)
    # capture_dynamic_output_shape_ops lets dynamo trace nonzero()'s
    # data-dependent output shape instead of graph-breaking on it.
    with dynamo_config.patch({"capture_dynamic_output_shape_ops": True}), \
            self.assertRaises(torch._dynamo.exc.TorchRuntimeError):
        # Result intentionally discarded: the call exists only to raise.
        torch.compile(fn, fullgraph=True)(x)
|
||||
|
||||
def test_autotuning(self):
|
||||
def fn(x, y):
|
||||
nz = torch.nonzero(x)
|
||||
|
||||
@ -1839,6 +1839,15 @@ class ExpandView(BaseView):
|
||||
if new_size[i] == -1:
|
||||
assert old_size[i] is not None
|
||||
new_size[i] = old_size[i]
|
||||
elif old_size[i] is None or old_size[i] == 1:
|
||||
pass
|
||||
else:
|
||||
# Expect broadcast compatibility
|
||||
new_size[i] = V.graph.sizevars.expect_equals(
|
||||
new_size[i],
|
||||
old_size[i],
|
||||
msg=f"Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}",
|
||||
)
|
||||
return new_size
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -328,6 +328,21 @@ class SizeVarAllocator:
|
||||
def guard_lt(self, left: Expr, right: Expr) -> None:
    """Assert that ``left < right`` evaluates to true under the shape env,
    installing the corresponding guard as a side effect of evaluation."""
    holds = self.shape_env.evaluate_expr(sympy.Lt(left, right))
    assert holds
|
||||
|
||||
def expect_true(self, expr: Expr, *, msg: str) -> None:
    """Record ``expr`` as a deferred runtime assertion.

    The expression is rewritten back through the inverse precomputed
    replacements before being handed to the shape env, and is attached
    to the currently-lowered fx node for error reporting.
    """
    canonical = sympy_subs(expr, self.inv_precomputed_replacements)
    self.shape_env.defer_runtime_assert(
        canonical, msg, fx_node=V.graph.current_node
    )
|
||||
|
||||
def expect_equals(self, left: Expr, right: Expr, *, msg: str) -> Expr:
    """Require ``left == right``, returning one of the two expressions.

    When neither side involves an unbacked symint the equality is guarded
    immediately. Otherwise a runtime assertion is deferred and the side
    *without* the unbacked symint is preferred as the return value.
    """
    is_unbacked = self.shape_env.is_unbacked_symint
    left_unbacked = is_unbacked(left)
    # Fully backed on both sides: guard now rather than deferring.
    if not left_unbacked and not is_unbacked(right):
        return self.guard_equals(left, right)
    self.expect_true(sympy.Eq(left, right), msg=msg)
    # Prefer returning the expression without unbacked symints.
    return right if left_unbacked else left
|
||||
|
||||
# The evaluate functions evaluate some symbolic sympy expression
|
||||
# (NB: not necessarily an Expr) and return what the concrete result
|
||||
# is, guarding on the expression being that result
|
||||
|
||||
Reference in New Issue
Block a user