[ARM] Fix test_float_to_int_conversion_nonfinite (#145367)

We have broken tests on Aarch64 which are not enabled upstream, this PR will fix and enable those tests.

```
AssertionError: Tensor-likes are not equal!

Mismatched elements: 2 / 3 (66.7%)
Greatest absolute difference: 1 at index (1,)
Greatest relative difference: 1.0842021724855044e-19 at index (1,)

To execute this test, run the following from the base repo dir:
    python test/test_tensor_creation_ops.py TestTensorCreationCPU.test_float_to_int_conversion_nonfinite_cpu_int64

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145367
Approved by: https://github.com/malfet
This commit is contained in:
Robert Hardwick
2025-02-11 22:22:10 +00:00
committed by PyTorch MergeBot
parent a20055288f
commit f59a56e56f
2 changed files with 13 additions and 7 deletions

View File

@ -1438,7 +1438,7 @@ test_executorch() {
test_linux_aarch64() {
python test/run_test.py --include test_modules test_mkldnn test_mkldnn_fusion test_openmp test_torch test_dynamic_shapes \
test_transformers test_multiprocessing test_numpy_interop test_autograd test_binary_ufuncs test_complex test_spectral_ops \
test_foreach test_reductions test_unary_ufuncs \
test_foreach test_reductions test_unary_ufuncs test_tensor_creation_ops \
--shard "$SHARD_NUMBER" "$NUM_TEST_SHARDS" --verbose
# Dynamo tests

View File

@ -33,6 +33,7 @@ from torch.testing._internal.common_utils import (
IS_FBCODE,
IS_SANDCASTLE,
IS_S390X,
IS_ARM64,
parametrize,
skipIfTorchDynamo,
xfailIfTorchDynamo,
@ -1088,13 +1089,18 @@ class TestTensorCreation(TestCase):
@dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
def test_float_to_int_conversion_nonfinite(self, device, dtype):
vals = (float('-inf'), float('inf'), float('nan'))
refs = 0
if dtype == torch.bool:
refs = True
elif dtype in (torch.int32, torch.int64):
refs = torch.iinfo(dtype).min
self._float_to_int_conversion_helper(vals, device, dtype, (refs, ) * 3)
if dtype == torch.bool:
refs = (True, True, True)
elif IS_ARM64:
refs = (torch.iinfo(dtype).min, torch.iinfo(dtype).max, 0)
if dtype in (torch.int8, torch.int16):
refs = (0, -1, 0)
else:
refs = (0, 0, 0)
if dtype in (torch.int32, torch.int64):
refs = (torch.iinfo(dtype).min, ) * 3
self._float_to_int_conversion_helper(vals, device, dtype, refs)
@onlyNativeDeviceTypes
def test_complex_type_conversions(self, device):