Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[ARM] Fix test_float_to_int_conversion_nonfinite (#145367)
We have broken tests on AArch64 which are not enabled upstream; this PR fixes and enables those tests.

```
AssertionError: Tensor-likes are not equal!

Mismatched elements: 2 / 3 (66.7%)
Greatest absolute difference: 1 at index (1,)
Greatest relative difference: 1.0842021724855044e-19 at index (1,)

To execute this test, run the following from the base repo dir:
    python test/test_tensor_creation_ops.py TestTensorCreationCPU.test_float_to_int_conversion_nonfinite_cpu_int64

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145367
Approved by: https://github.com/malfet
This commit: f59a56e56f (parent a20055288f), committed by PyTorch MergeBot.
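For context, the failing comparison comes from converting non-finite floats (-inf, inf, nan) to integer dtypes, which C++ leaves undefined, so the result differs between x86 and AArch64. A minimal sketch of the conversion the test exercises (the printed values are architecture-dependent; the exact expectations are the ones encoded in the test/test_tensor_creation_ops.py diff below):

```python
import torch

# Non-finite float -> int conversion is undefined behaviour in C++, so the
# result depends on the CPU: x86 typically yields INT_MIN for all of
# -inf/inf/nan, while AArch64 saturates (-inf -> INT_MIN, inf -> INT_MAX)
# and converts nan to 0.
vals = torch.tensor([float("-inf"), float("inf"), float("nan")])
print(vals.to(torch.int64))  # differs between x86 and AArch64
print(vals.to(torch.int16))  # narrow dtypes differ again on AArch64
```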
@@ -1438,7 +1438,7 @@ test_executorch() {
 test_linux_aarch64() {
   python test/run_test.py --include test_modules test_mkldnn test_mkldnn_fusion test_openmp test_torch test_dynamic_shapes \
     test_transformers test_multiprocessing test_numpy_interop test_autograd test_binary_ufuncs test_complex test_spectral_ops \
-    test_foreach test_reductions test_unary_ufuncs \
+    test_foreach test_reductions test_unary_ufuncs test_tensor_creation_ops \
     --shard "$SHARD_NUMBER" "$NUM_TEST_SHARDS" --verbose

   # Dynamo tests
@@ -33,6 +33,7 @@ from torch.testing._internal.common_utils import (
     IS_FBCODE,
     IS_SANDCASTLE,
     IS_S390X,
+    IS_ARM64,
     parametrize,
     skipIfTorchDynamo,
     xfailIfTorchDynamo,
@@ -1088,13 +1089,18 @@ class TestTensorCreation(TestCase):
     @dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
     def test_float_to_int_conversion_nonfinite(self, device, dtype):
         vals = (float('-inf'), float('inf'), float('nan'))
-        refs = 0
-        if dtype == torch.bool:
-            refs = True
-        elif dtype in (torch.int32, torch.int64):
-            refs = torch.iinfo(dtype).min
-
-        self._float_to_int_conversion_helper(vals, device, dtype, (refs, ) * 3)
+        if dtype == torch.bool:
+            refs = (True, True, True)
+        elif IS_ARM64:
+            refs = (torch.iinfo(dtype).min, torch.iinfo(dtype).max, 0)
+            if dtype in (torch.int8, torch.int16):
+                refs = (0, -1, 0)
+        else:
+            refs = (0, 0, 0)
+            if dtype in (torch.int32, torch.int64):
+                refs = (torch.iinfo(dtype).min, ) * 3
+
+        self._float_to_int_conversion_helper(vals, device, dtype, refs)

     @onlyNativeDeviceTypes
     def test_complex_type_conversions(self, device):
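The reference tuples above encode the per-architecture conversion results for (-inf, inf, nan). As a rough sketch of what the comparison boils down to (a simplified, hypothetical check_refs stand-in for the real _float_to_int_conversion_helper defined earlier in test_tensor_creation_ops.py):

```python
import torch

def check_refs(vals, device, dtype, refs):
    # Simplified, hypothetical stand-in for _float_to_int_conversion_helper:
    # convert the non-finite floats to the target dtype and compare against
    # the architecture-specific reference values.
    converted = torch.tensor(vals, device=device, dtype=torch.float).to(dtype)
    expected = torch.tensor(refs, device=device, dtype=dtype)
    torch.testing.assert_close(converted, expected)

# For example, on AArch64 the int64 refs are (INT64_MIN, INT64_MAX, 0),
# while on x86 all three values convert to INT64_MIN.
```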