mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Remove guard_size_oblivious from vector_norm decomposition. (#148809)
This PR remove the usage of guard_size_oblivious in vector_norm by inlining it in the runtime check, this prevent any data dependent error from ever appearing here at the locations where guard_size_oblivious used to exist. Before this PR it used to break potentially. This is NOT BC breaking or changing of semantics from eager. Pull Request resolved: https://github.com/pytorch/pytorch/pull/148809 Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
e6969c1bd8
commit
5471e80fb4
@ -14,6 +14,8 @@ import random
|
||||
from random import randrange
|
||||
from itertools import product
|
||||
from functools import reduce, partial
|
||||
from typing import Union, Optional
|
||||
from torch._prims_common import DimsType
|
||||
|
||||
from torch.testing._internal.common_utils import \
|
||||
(TestCase, run_tests, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest,
|
||||
@ -1475,6 +1477,61 @@ class TestLinalg(TestCase):
|
||||
keepdim,
|
||||
norm_dtype)
|
||||
|
||||
|
||||
def test_vector_norm_decom_unbacked_checks(self):
|
||||
from torch._refs.linalg import _check_vector_norm_args
|
||||
|
||||
class Mod(torch.nn.Module):
|
||||
def __init__(self, ord, dim):
|
||||
super().__init__()
|
||||
self.ord = ord
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, a):
|
||||
x = a.item()
|
||||
tensor_unbacked_size = torch.ones(x, x + 1, x + 2)
|
||||
_check_vector_norm_args(tensor_unbacked_size, self.ord, self.dim)
|
||||
return tensor_unbacked_size
|
||||
|
||||
def test(
|
||||
ord: Union[float, int],
|
||||
dim: Optional[DimsType],
|
||||
expect_numel_runtime_check: bool,
|
||||
expect_index_0_check: bool = False,
|
||||
) -> None:
|
||||
m = Mod(ord, dim)
|
||||
exported_program: torch.export.ExportedProgram = torch.export.export(
|
||||
m, args=tuple(torch.tensor([1]))
|
||||
)
|
||||
self.assertEqual(
|
||||
"Runtime assertion failed for expression Ne(u0*(u0 + 1)*(u0 + 2), 0)"
|
||||
in exported_program.graph_module.code,
|
||||
expect_numel_runtime_check,
|
||||
)
|
||||
self.assertEqual(
|
||||
"Runtime assertion failed for expression Ne(u0, 0) | Ne(u0*(u0 + 1)*(u0 + 2), 0)"
|
||||
in exported_program.graph_module.code,
|
||||
expect_index_0_check,
|
||||
)
|
||||
|
||||
# dim is int
|
||||
test(-1, 1, True)
|
||||
|
||||
# dim is None
|
||||
test(-1, None, True)
|
||||
|
||||
# len(dim) == 0
|
||||
test(-1, [], True)
|
||||
|
||||
# shape[d] == 0
|
||||
test(-1, [0], False, True)
|
||||
|
||||
# u0 + 1 == 0 is False we do not see a runtime assert in the generated graph.
|
||||
test(-1, [1], False, False)
|
||||
|
||||
test(-1, [0, 1], False, True)
|
||||
test(-1, [0, 0], False, True)
|
||||
|
||||
def test_vector_norm_dim_tuple_arg(self, device):
|
||||
test_cases = [
|
||||
# input size, dim, error, error message
|
||||
|
@ -2099,7 +2099,7 @@ for __name in dir(_C._VariableFunctions):
|
||||
__obj.__module__ = __name__ # "torch"
|
||||
# Hide some APIs that should not be public
|
||||
if __name == "segment_reduce":
|
||||
# TODO: Once the undocumented FC window is passed, remove the line bellow
|
||||
# TODO: Once the undocumented FC window is passed, remove the line below
|
||||
globals()[__name] = __obj
|
||||
__name = "_" + __name
|
||||
globals()[__name] = __obj
|
||||
|
@ -97,6 +97,34 @@ def diagonal(
|
||||
return torch.diagonal(input, offset=offset, dim1=dim1, dim2=dim2)
|
||||
|
||||
|
||||
def _check_vector_norm_args(
|
||||
x: TensorLikeType, ord: Union[float, int] = 2, dim: Optional[DimsType] = None
|
||||
):
|
||||
from torch.fx.experimental.symbolic_shapes import sym_or
|
||||
|
||||
if not (ord < 0.0 or ord == float("inf")):
|
||||
return
|
||||
|
||||
torch._check(
|
||||
sym_or(
|
||||
x.numel() != 0,
|
||||
not isinstance(dim, IntLike) and dim is not None and len(dim) != 0,
|
||||
),
|
||||
"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
|
||||
"because the operation does not have an identity",
|
||||
)
|
||||
|
||||
shape = x.shape
|
||||
if dim is not None and not isinstance(dim, IntLike):
|
||||
for d in dim:
|
||||
torch._check(
|
||||
sym_or(x.numel() != 0, d < len(shape) and d >= 0 and shape[d] != 0),
|
||||
"linalg.vector_norm cannot compute the {ord} norm on the "
|
||||
f"dimension {d} because this dimension is empty and the "
|
||||
"operation does not have an identity",
|
||||
)
|
||||
|
||||
|
||||
@register_decomposition(torch._ops.ops.aten.linalg_vector_norm)
|
||||
@out_wrapper(exact_dtype=True)
|
||||
def vector_norm(
|
||||
@ -107,29 +135,13 @@ def vector_norm(
|
||||
*,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> Tensor:
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
|
||||
# Checks
|
||||
check_fp_or_complex(x.dtype, "linalg.vector_norm")
|
||||
|
||||
if isinstance(dim, Dim):
|
||||
dim = [dim] # type: ignore[assignment]
|
||||
|
||||
if guard_size_oblivious(x.numel() == 0) and (ord < 0.0 or ord == float("inf")):
|
||||
torch._check(
|
||||
dim is not None and len(dim) != 0,
|
||||
lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
|
||||
"because the operation does not have an identity",
|
||||
)
|
||||
shape = x.shape
|
||||
assert dim is not None # mypy does not seem to be able to see through check?
|
||||
for d in dim:
|
||||
torch._check(
|
||||
shape[d] != 0,
|
||||
lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
|
||||
f"dimension {d} because this dimension is empty and the "
|
||||
"operation does not have an identity",
|
||||
)
|
||||
_check_vector_norm_args(x, ord, dim)
|
||||
|
||||
_check_norm_dtype(dtype, x.dtype, "linalg.vector_norm")
|
||||
|
||||
computation_dtype, result_dtype = utils.reduction_dtypes(
|
||||
|
@ -1299,6 +1299,19 @@ def statically_known_true(x: Union[bool, SymBool]) -> bool:
|
||||
return result
|
||||
|
||||
|
||||
# When a or b is evaluated, a is evaluated eagerly first then b. This causes
|
||||
# a data dependent error for an expression “if u0==1 or True”. or over guarding for
|
||||
# “if s0==1 or True”.
|
||||
|
||||
# On the other hand, when we use operator.or_, then dynamo will generate
|
||||
# a sympy expression Sympy.Or(u0==1, True) without evaluating the args first.
|
||||
|
||||
# When the whole expression is passed to evaluation in that case, we do not throw a
|
||||
# data dependent error or guard because we can statically know the result is True
|
||||
# before unpacking the symbols.
|
||||
sym_or = operator.or_
|
||||
|
||||
|
||||
def sym_eq(x: _T, y: _T) -> Union[bool, SymBool]:
|
||||
"""
|
||||
Like ==, but when run on list/tuple, it will recursively test equality
|
||||
|
Reference in New Issue
Block a user