mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	Perform reciprocal optimization with foreach_div (#128433)
Fixes https://github.com/pytorch/pytorch/issues/114165 Internal xref https://fb.workplace.com/groups/1144215345733672/posts/2801223606699496/ Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/128433 Approved by: https://github.com/awgu
This commit is contained in:
		
				
					committed by
					
						
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							8db4a41973
						
					
				
				
					commit
					2fa6f80b13
				
			@ -1206,6 +1206,17 @@ class TestForeach(TestCase):
 | 
			
		||||
        actual = torch._foreach_div(tensors, scalar_cpu_tensor)
 | 
			
		||||
        self.assertEqual(actual, [t.div(scalar_cpu_tensor) for t in tensors])
 | 
			
		||||
 | 
			
		||||
    @onlyCUDA
 | 
			
		||||
    def test_div_reciprocal(self):
 | 
			
		||||
        expect_m, expect_e = torch.frexp(
 | 
			
		||||
            torch.div(torch.tensor(0.1, device="cuda"), 10.0)
 | 
			
		||||
        )
 | 
			
		||||
        actual_m, actual_e = torch.frexp(
 | 
			
		||||
            torch._foreach_div([torch.tensor(0.1, device="cuda")], [10.0])[0]
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(expect_m, actual_m)
 | 
			
		||||
        self.assertEqual(expect_e, actual_e)
 | 
			
		||||
 | 
			
		||||
    @onlyCUDA
 | 
			
		||||
    def test_0dim_tensor_overload_exception(self):
 | 
			
		||||
        # check exceptions of fast path
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user