Remove guard_size_oblivious from vector_norm decomposition. (#148809)

This PR removes the usage of guard_size_oblivious in vector_norm by inlining the condition into the runtime check,
which prevents any data-dependent error from appearing at the locations where guard_size_oblivious
used to exist. Before this PR, tracing a tensor with unbacked sizes could break here. This is NOT BC-breaking and does not change semantics from eager.
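
To illustrate the pattern (a minimal, hypothetical sketch, not code from this commit): rather than branching on a data-dependent condition at trace time, the condition is folded into the predicate of torch._check with sym_or, so export records one deferred runtime assert instead of raising.

import torch
from torch.fx.experimental.symbolic_shapes import sym_or

class M(torch.nn.Module):
    def forward(self, a):
        x = a.item()                     # unbacked SymInt (u0) under export
        t = torch.ones(x, x + 1, x + 2)  # numel is u0*(u0 + 1)*(u0 + 2)
        # Old pattern: `if guard_size_oblivious(t.numel() == 0): torch._check(...)`
        # evaluates the condition at trace time and can raise a data-dependent error.
        # New pattern: inline the condition into the check itself.
        torch._check(
            sym_or(t.numel() != 0, False),  # second disjunct elided for brevity
            lambda: "cannot reduce an empty tensor without an identity",
        )
        return t

ep = torch.export.export(M(), args=(torch.tensor([3]),))
print(ep.graph_module.code)  # should contain a deferred Ne(u0*(u0 + 1)*(u0 + 2), 0) assert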

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148809
Approved by: https://github.com/bobrenjc93
Laith Sakka
2025-04-03 18:26:20 -07:00
committed by PyTorch MergeBot
parent e6969c1bd8
commit 5471e80fb4
4 changed files with 101 additions and 19 deletions

test/test_linalg.py

@@ -14,6 +14,8 @@ import random
from random import randrange
from itertools import product
from functools import reduce, partial
from typing import Union, Optional
from torch._prims_common import DimsType
from torch.testing._internal.common_utils import \
    (TestCase, run_tests, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest,
@@ -1475,6 +1477,61 @@ class TestLinalg(TestCase):
                keepdim,
                norm_dtype)

    def test_vector_norm_decom_unbacked_checks(self):
        from torch._refs.linalg import _check_vector_norm_args

        class Mod(torch.nn.Module):
            def __init__(self, ord, dim):
                super().__init__()
                self.ord = ord
                self.dim = dim

            def forward(self, a):
                x = a.item()
                tensor_unbacked_size = torch.ones(x, x + 1, x + 2)
                _check_vector_norm_args(tensor_unbacked_size, self.ord, self.dim)
                return tensor_unbacked_size

        def test(
            ord: Union[float, int],
            dim: Optional[DimsType],
            expect_numel_runtime_check: bool,
            expect_index_0_check: bool = False,
        ) -> None:
            m = Mod(ord, dim)
            exported_program: torch.export.ExportedProgram = torch.export.export(
                m, args=(torch.tensor([1]),)
            )
            self.assertEqual(
                "Runtime assertion failed for expression Ne(u0*(u0 + 1)*(u0 + 2), 0)"
                in exported_program.graph_module.code,
                expect_numel_runtime_check,
            )
            self.assertEqual(
                "Runtime assertion failed for expression Ne(u0, 0) | Ne(u0*(u0 + 1)*(u0 + 2), 0)"
                in exported_program.graph_module.code,
                expect_index_0_check,
            )

        # dim is int
        test(-1, 1, True)

        # dim is None
        test(-1, None, True)

        # len(dim) == 0
        test(-1, [], True)

        # shape[d] == 0
        test(-1, [0], False, True)

        # shape[1] == u0 + 1, and u0 + 1 == 0 is statically False, so we do not
        # see a runtime assert in the generated graph.
        test(-1, [1], False, False)

        test(-1, [0, 1], False, True)
        test(-1, [0, 0], False, True)
    def test_vector_norm_dim_tuple_arg(self, device):
        test_cases = [
            # input size, dim, error, error message
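
As an aside on why the test expects those exact strings (a sketch, following the _check_vector_norm_args helper shown below in torch/_refs/linalg):

# x = torch.ones(u0, u0 + 1, u0 + 2)  =>  x.numel() == u0*(u0 + 1)*(u0 + 2)
#
# numel check (dim an int, None, or []): the second disjunct of sym_or is
#   statically False, so the assert lowers to  Ne(u0*(u0 + 1)*(u0 + 2), 0)
#
# per-dimension check (dim=[0], shape[0] == u0): sym_or(x.numel() != 0, shape[0] != 0)
#   lowers to  Ne(u0, 0) | Ne(u0*(u0 + 1)*(u0 + 2), 0)   (sympy sorts the operands)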

torch/__init__.py

@@ -2099,7 +2099,7 @@ for __name in dir(_C._VariableFunctions):
     __obj.__module__ = __name__  # "torch"
     # Hide some APIs that should not be public
     if __name == "segment_reduce":
-        # TODO: Once the undocumented FC window is passed, remove the line bellow
+        # TODO: Once the undocumented FC window is passed, remove the line below
         globals()[__name] = __obj
         __name = "_" + __name
     globals()[__name] = __obj

torch/_refs/linalg/__init__.py

@@ -97,6 +97,34 @@ def diagonal(
    return torch.diagonal(input, offset=offset, dim1=dim1, dim2=dim2)


def _check_vector_norm_args(
    x: TensorLikeType, ord: Union[float, int] = 2, dim: Optional[DimsType] = None
):
    from torch.fx.experimental.symbolic_shapes import sym_or

    if not (ord < 0.0 or ord == float("inf")):
        return

    torch._check(
        sym_or(
            x.numel() != 0,
            not isinstance(dim, IntLike) and dim is not None and len(dim) != 0,
        ),
        lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
        "because the operation does not have an identity",
    )

    shape = x.shape
    if dim is not None and not isinstance(dim, IntLike):
        for d in dim:
            torch._check(
                sym_or(x.numel() != 0, d < len(shape) and d >= 0 and shape[d] != 0),
                lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
                f"dimension {d} because this dimension is empty and the "
                "operation does not have an identity",
            )
@register_decomposition(torch._ops.ops.aten.linalg_vector_norm)
@out_wrapper(exact_dtype=True)
def vector_norm(

@@ -107,29 +135,13 @@ def vector_norm(
     *,
     dtype: Optional[torch.dtype] = None,
 ) -> Tensor:
-    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
     # Checks
     check_fp_or_complex(x.dtype, "linalg.vector_norm")

     if isinstance(dim, Dim):
         dim = [dim]  # type: ignore[assignment]

-    if guard_size_oblivious(x.numel() == 0) and (ord < 0.0 or ord == float("inf")):
-        torch._check(
-            dim is not None and len(dim) != 0,
-            lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
-            "because the operation does not have an identity",
-        )
-        shape = x.shape
-        assert dim is not None  # mypy does not seem to be able to see through check?
-        for d in dim:
-            torch._check(
-                shape[d] != 0,
-                lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
-                f"dimension {d} because this dimension is empty and the "
-                "operation does not have an identity",
-            )
+    _check_vector_norm_args(x, ord, dim)

     _check_norm_dtype(dtype, x.dtype, "linalg.vector_norm")

     computation_dtype, result_dtype = utils.reduction_dtypes(
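
As the commit message notes, eager semantics are unchanged; a quick eager-mode sanity check (a hypothetical snippet, not part of the diff) should still raise the familiar error:

import torch

try:
    torch.linalg.vector_norm(torch.ones(2, 0, 3), ord=float("inf"), dim=(1,))
except RuntimeError as e:
    # expected message: "linalg.vector_norm cannot compute the inf norm on ...
    # because ... does not have an identity"
    print(e)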

torch/fx/experimental/symbolic_shapes.py

@@ -1299,6 +1299,19 @@ def statically_known_true(x: Union[bool, SymBool]) -> bool:
    return result


# When `a or b` is evaluated, `a` is eagerly coerced to a bool first, then `b`.
# This causes a data-dependent error for an expression like `if u0 == 1 or True`,
# or over-guarding for `if s0 == 1 or True`.
# When we use operator.or_ instead, dynamo generates the sympy expression
# sympy.Or(u0 == 1, True) without evaluating the arguments first. When the whole
# expression is then evaluated, we do not throw a data-dependent error or install
# a guard, because we can statically know the result is True before unpacking the
# symbols.
sym_or = operator.or_
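
A pure-sympy analogue of the comment above (an assumed, standalone example; in torch the operands are SymBools backed by sympy expressions):

import operator
import sympy

u0 = sympy.Symbol("u0", integer=True)

# `sympy.Eq(u0, 1) or True` would first call bool(Eq(u0, 1)), which raises
# "TypeError: cannot determine truth value of Relational" -- the analogue of a
# data-dependent error on an unbacked symbol.
expr = operator.or_(sympy.Eq(u0, 1), sympy.true)
print(expr)  # True -- statically known without ever deciding Eq(u0, 1)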
def sym_eq(x: _T, y: _T) -> Union[bool, SymBool]:
    """
    Like ==, but when run on list/tuple, it will recursively test equality