mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	[MPS] Fix [nan]median output for empty tensors (#162846)
				
					
				
			It should be `NaN` rather than 0 Added respective checks to `test_empty_tensor` Fixes https://github.com/pytorch/pytorch/issues/162798 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162846 Approved by: https://github.com/dcci
This commit is contained in:
		
				
					committed by
					
						
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							ee53ad2dd0
						
					
				
				
					commit
					d25c35d2b2
				
			@ -617,6 +617,7 @@ static Tensor median_common_mps(const Tensor& input_t, bool nanmedian) {
 | 
			
		||||
  // we allocate 1 here due to MacOS13 bug for gather MPSGraph op, look below for the error
 | 
			
		||||
  Tensor output_t = at::empty({1}, input_t.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);
 | 
			
		||||
  if (output_t.numel() == 0 || num_in_elements == 0) {
 | 
			
		||||
    output_t.fill_(std::numeric_limits<float>::quiet_NaN());
 | 
			
		||||
    return output_t;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -11628,6 +11628,9 @@ class TestAdvancedIndexing(TestCaseMPS):
 | 
			
		||||
    def test_empty_reduce(self, device="mps"):
 | 
			
		||||
        x = torch.rand(0, 3, device=device)
 | 
			
		||||
        self.assertTrue(x.mean().isnan())
 | 
			
		||||
        self.assertTrue(x.nanmean().isnan())
 | 
			
		||||
        self.assertTrue(x.median().isnan())
 | 
			
		||||
        self.assertTrue(x.nanmedian().isnan())
 | 
			
		||||
        self.assertEqual(x.count_nonzero(), 0)
 | 
			
		||||
        self.assertEqual(x.sum(), 0)
 | 
			
		||||
        self.assertEqual(x.nansum(), 0)
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user