mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	fix test_float_to_int_conversion_nonfinite for NumPy 2 (#138131)
Related to #107302 We saw `test_float_to_int_conversion_nonfinite` failed as we upgrade to NumPy 2. It is caused by the undefined behavior of `numpy` casting `inf`, `-inf` and `nan` from `np.float32` to other dtypes. The test is using NumPy as reference for the ground truth. (see line 1013-1015) However, these behaviors are undefined in NumPy. If you do `np.array([float("inf")]).astype(np.uint8, casting="safe")`, it results in an error `TypeError: Cannot cast array data from dtype('float64') to dtype('uint8') according to the rule 'safe'`. The undefined behaviors are always subject to change. This PR address this issue by passing concrete values as the ground truth references. In the future, even NumPy changes its behavior the test would still remain stable. Pull Request resolved: https://github.com/pytorch/pytorch/pull/138131 Approved by: https://github.com/drisspg
This commit is contained in:
		
				
					committed by
					
						
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							d32eac86f3
						
					
				
				
					commit
					e6083016b3
				
			@ -1081,11 +1081,17 @@ class TestTensorCreation(TestCase):
 | 
			
		||||
    # NB: torch.uint16, torch.uint32, torch.uint64 excluded as this
 | 
			
		||||
    # nondeterministically fails, warning "invalid value encountered in cast"
 | 
			
		||||
    @onlyCPU
 | 
			
		||||
    @unittest.skipIf(IS_MACOS, "Nonfinite conversion results on MacOS are different from others.")
 | 
			
		||||
    @dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
 | 
			
		||||
    def test_float_to_int_conversion_nonfinite(self, device, dtype):
 | 
			
		||||
        vals = (float('-inf'), float('inf'), float('nan'))
 | 
			
		||||
        refs = 0
 | 
			
		||||
        if dtype == torch.bool:
 | 
			
		||||
            refs = True
 | 
			
		||||
        elif dtype in (torch.int32, torch.int64):
 | 
			
		||||
            refs = torch.iinfo(dtype).min
 | 
			
		||||
 | 
			
		||||
        self._float_to_int_conversion_helper(vals, device, dtype)
 | 
			
		||||
        self._float_to_int_conversion_helper(vals, device, dtype, (refs, ) * 3)
 | 
			
		||||
 | 
			
		||||
    @onlyNativeDeviceTypes
 | 
			
		||||
    def test_complex_type_conversions(self, device):
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user