diff --git a/test/test_linalg.py b/test/test_linalg.py
index 649c46b5404c..ea841a3b18fc 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -14,6 +14,8 @@ import random
 from random import randrange
 from itertools import product
 from functools import reduce, partial
+from typing import Union, Optional
+from torch._prims_common import DimsType
 
 from torch.testing._internal.common_utils import \
     (TestCase, run_tests, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest,
@@ -1475,6 +1477,61 @@ class TestLinalg(TestCase):
                         keepdim, norm_dtype)
+
+    def test_vector_norm_decom_unbacked_checks(self):
+        from torch._refs.linalg import _check_vector_norm_args
+
+        class Mod(torch.nn.Module):
+            def __init__(self, ord, dim):
+                super().__init__()
+                self.ord = ord
+                self.dim = dim
+
+            def forward(self, a):
+                x = a.item()
+                tensor_unbacked_size = torch.ones(x, x + 1, x + 2)
+                _check_vector_norm_args(tensor_unbacked_size, self.ord, self.dim)
+                return tensor_unbacked_size
+
+        def test(
+            ord: Union[float, int],
+            dim: Optional[DimsType],
+            expect_numel_runtime_check: bool,
+            expect_index_0_check: bool = False,
+        ) -> None:
+            m = Mod(ord, dim)
+            exported_program: torch.export.ExportedProgram = torch.export.export(
+                m, args=(torch.tensor([1]),)
+            )
+            self.assertEqual(
+                "Runtime assertion failed for expression Ne(u0*(u0 + 1)*(u0 + 2), 0)"
+                in exported_program.graph_module.code,
+                expect_numel_runtime_check,
+            )
+            self.assertEqual(
+                "Runtime assertion failed for expression Ne(u0, 0) | Ne(u0*(u0 + 1)*(u0 + 2), 0)"
+                in exported_program.graph_module.code,
+                expect_index_0_check,
+            )
+
+        # dim is int
+        test(-1, 1, True)
+
+        # dim is None
+        test(-1, None, True)
+
+        # len(dim) == 0
+        test(-1, [], True)
+
+        # shape[d] == 0
+        test(-1, [0], False, True)
+
+        # Since u0 + 1 == 0 is statically False, we do not see a runtime assert in the generated graph.
+        test(-1, [1], False, False)
+
+        test(-1, [0, 1], False, True)
+        test(-1, [0, 0], False, True)
+
     def test_vector_norm_dim_tuple_arg(self, device):
         test_cases = [
             # input size, dim, error, error message
diff --git a/torch/__init__.py b/torch/__init__.py
index f0b20bb8c7c3..3ac4c52e4380 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -2099,7 +2099,7 @@ for __name in dir(_C._VariableFunctions):
     __obj.__module__ = __name__  # "torch"
     # Hide some APIs that should not be public
     if __name == "segment_reduce":
-        # TODO: Once the undocumented FC window is passed, remove the line bellow
+        # TODO: Once the undocumented FC window is passed, remove the line below
         globals()[__name] = __obj
         __name = "_" + __name
     globals()[__name] = __obj
diff --git a/torch/_refs/linalg/__init__.py b/torch/_refs/linalg/__init__.py
index c85962f22842..f3aca0d776c2 100644
--- a/torch/_refs/linalg/__init__.py
+++ b/torch/_refs/linalg/__init__.py
@@ -97,6 +97,34 @@ def diagonal(
     return torch.diagonal(input, offset=offset, dim1=dim1, dim2=dim2)
 
 
+def _check_vector_norm_args(
+    x: TensorLikeType, ord: Union[float, int] = 2, dim: Optional[DimsType] = None
+):
+    from torch.fx.experimental.symbolic_shapes import sym_or
+
+    if not (ord < 0.0 or ord == float("inf")):
+        return
+
+    torch._check(
+        sym_or(
+            x.numel() != 0,
+            not isinstance(dim, IntLike) and dim is not None and len(dim) != 0,
+        ),
+        lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
+        "because the operation does not have an identity",
+    )
+
+    shape = x.shape
+    if dim is not None and not isinstance(dim, IntLike):
+        for d in dim:
+            torch._check(
+                sym_or(x.numel() != 0, d < len(shape) and d >= 0 and shape[d] != 0),
+                lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
+                f"dimension {d} because this dimension is empty and the "
+                "operation does not have an identity",
+            )
+
+
 @register_decomposition(torch._ops.ops.aten.linalg_vector_norm)
 @out_wrapper(exact_dtype=True)
 def vector_norm(
@@ -107,29 +135,13 @@ def vector_norm(
     *,
     dtype: Optional[torch.dtype] = None,
 ) -> Tensor:
-    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
-
-    # Checks
     check_fp_or_complex(x.dtype, "linalg.vector_norm")
 
     if isinstance(dim, Dim):
         dim = [dim]  # type: ignore[assignment]
 
-    if guard_size_oblivious(x.numel() == 0) and (ord < 0.0 or ord == float("inf")):
-        torch._check(
-            dim is not None and len(dim) != 0,
-            lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
-            "because the operation does not have an identity",
-        )
-        shape = x.shape
-        assert dim is not None  # mypy does not seem to be able to see through check?
-        for d in dim:
-            torch._check(
-                shape[d] != 0,
-                lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
-                f"dimension {d} because this dimension is empty and the "
-                "operation does not have an identity",
-            )
+    _check_vector_norm_args(x, ord, dim)
+
     _check_norm_dtype(dtype, x.dtype, "linalg.vector_norm")
 
     computation_dtype, result_dtype = utils.reduction_dtypes(
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index 6ac25ceceff4..c6ba99f65ffd 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -1299,6 +1299,19 @@ def statically_known_true(x: Union[bool, SymBool]) -> bool:
     return result
 
 
+# When `a or b` is evaluated, `a` is eagerly coerced to bool before `b` is
+# considered. This causes a data-dependent error for an expression like
+# `if u0 == 1 or True`, or over-guarding for `if s0 == 1 or True`.
+
+# On the other hand, when we use operator.or_, dynamo will generate the sympy
+# expression sympy.Or(u0 == 1, True) without evaluating the args first.
+
+# When that whole expression is then evaluated, we do not throw a data-dependent
+# error or install a guard, because we can statically know the result is True
+# before unpacking the symbols.
+sym_or = operator.or_
+
+
 def sym_eq(x: _T, y: _T) -> Union[bool, SymBool]:
     """
     Like ==, but when run on list/tuple, it will recursively test equality
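Note on the sym_or comment above: Python's `or` cannot be overloaded, because it short-circuits by calling bool() on its left operand, whereas operator.or_ dispatches to __or__ and lets a symbolic boolean build a combined expression instead. Below is a minimal, self-contained sketch of that difference, using a toy Expr class as a hypothetical stand-in for SymBool; it is not the real symbolic-shapes machinery:

    import operator

    class Expr:
        """Toy stand-in for a symbolic bool: records an expression instead of a value."""

        def __init__(self, s):
            self.s = s

        def __or__(self, other):
            # Combine both sides without forcing either to a concrete bool.
            return Expr(f"Or({self.s}, {getattr(other, 's', other)})")

        def __bool__(self):
            # Mirrors the data-dependent error: no concrete value is available.
            raise RuntimeError(f"cannot evaluate {self.s} to a concrete bool")

    u0_eq_1 = Expr("Eq(u0, 1)")

    # `or` coerces its left operand to bool first, so this raises even though
    # the overall result is trivially True -- like `if u0 == 1 or True` above.
    try:
        _ = u0_eq_1 or True
    except RuntimeError as e:
        print(e)

    # operator.or_ builds Or(Eq(u0, 1), True), which can later be simplified
    # to True without ever inspecting u0 -- the behavior sym_or relies on.
    print(operator.or_(u0_eq_1, True).s)  # Or(Eq(u0, 1), True)

This is also why _check_vector_norm_args passes both operands into sym_or(...) rather than chaining them with `or`: the unbacked x.numel() != 0 never has to be evaluated on its own.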