Fix Dynamo tests failing with "Failed running call_function <built-in function linalg_norm>" (#120993)

When iterating over an array of `ord` values, we share the same torchdynamo context. This makes dynamo treat the `ord` variable as a dynamic shape, causing problems.

In the `vector_norm` decomposition, casting the int type ord to float will fix this problem.

Fixes https://github.com/pytorch/pytorch/issues/119795
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120993
Approved by: https://github.com/lezcano
This commit is contained in:
Xu Zhao
2024-03-01 20:27:37 +00:00
committed by PyTorch MergeBot
parent 39e4d1a535
commit 7a64eb65e4
2 changed files with 5 additions and 14 deletions

View File

@@ -1895,13 +1895,6 @@ symbolic_tensor_failures = {
xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
xfail('unique', ''), # aten._unique2.default - couldn't find symbolic meta function/decomposition
# AssertionError: False != True - https://github.com/pytorch/pytorch/issues/113905
xfail('dist', ''),
xfail('norm', ''),
xfail('linalg.vector_norm', ''),
xfail('linalg.norm', 'subgradients_at_zero'),
xfail('renorm', ''),
xfail('max_pool2d_with_indices_backward', ''), # Expected a value of type 'List[int]' for argument 'kernel_size' but...
# many complex operators incorrect striding, metadata
@@ -1928,10 +1921,6 @@ symbolic_tensor_segfaults = {
symbolic_tensor_failures.update(symbolic_tensor_segfaults)
outplace_symbolic_tensor_failures = {
xfail('linalg.norm', ''),
}
inplace_symbolic_tensor_failures = {
# bugs
xfail('float_power', ''), # base given to float_power_ has dtype Float but the operation's result requires dtype Double
@@ -2029,7 +2018,7 @@ class TestProxyTensorOpInfo(TestCase):
@ops(op_db + custom_op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
@skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive',
make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | outplace_symbolic_tensor_failures)
make_fx_failures | fake_tensor_failures | symbolic_tensor_failures)
def test_make_fx_symbolic_exhaustive(self, device, dtype, op):
_test_make_fx_helper(self, device, dtype, op, "symbolic")

View File

@@ -16,6 +16,7 @@ from torch._prims_common import (
Dim,
DimsType,
ELEMENTWISE_TYPE_PROMOTION_KIND,
IntLike,
NumberType,
TensorLikeType,
)
@@ -101,7 +102,7 @@ def diagonal(
@out_wrapper(exact_dtype=True)
def vector_norm(
x: TensorLikeType,
ord: float = 2.0,
ord: Union[float, int] = 2,
dim: Optional[DimsType] = None,
keepdim: bool = False,
*,
@@ -148,7 +149,8 @@ def vector_norm(
x = _maybe_convert_to_dtype(x, computation_dtype) # type: ignore[assignment]
reduce_sum = partial(torch.sum, dim=dim, keepdim=keepdim)
if not (ord % 2.0 == 0.0 and utils.is_float_dtype(x.dtype)):
is_ord_even = ord % 2 == 0 if isinstance(ord, IntLike) else ord % 2.0 == 0.0
if not (is_ord_even and utils.is_float_dtype(x.dtype)):
x = torch.abs(x)
return to_result_dtype(torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord)) # type: ignore[return-value]