mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	Ensure autocast device_type is a string + Unit test (#125014)
Reviving #124873 (already approved) to resolve CLA issues Fixes #124738 (Marked as draft until I get local unit tests to run) Edit: Tests passing Pull Request resolved: https://github.com/pytorch/pytorch/pull/125014 Approved by: https://github.com/mikaylagawarecki, https://github.com/soulitzer
This commit is contained in:
		
				
					committed by
					
						
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							1a0b247762
						
					
				
				
					commit
					6761b49551
				
			@ -343,6 +343,13 @@ class TestTorchAutocast(TestCase):
 | 
				
			|||||||
        with self.assertRaisesRegex(RuntimeError, msg):
 | 
					        with self.assertRaisesRegex(RuntimeError, msg):
 | 
				
			||||||
            assert torch.amp.is_autocast_available(device_type=dev)
 | 
					            assert torch.amp.is_autocast_available(device_type=dev)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_non_string_device(self):
 | 
				
			||||||
 | 
					        """Test that `autocast` throws a ValueError when provided a `torch.device` object for `device_type` instead of a string"""
 | 
				
			||||||
 | 
					        dev = torch.device("cpu")
 | 
				
			||||||
 | 
					        msg = f"Expected `device_type` of type `str`, got: `{type(dev)}`"
 | 
				
			||||||
 | 
					        with self.assertRaisesRegex(expected_exception=ValueError, expected_regex=msg):
 | 
				
			||||||
 | 
					            torch.autocast(device_type=dev)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    run_tests()
 | 
					    run_tests()
 | 
				
			||||||
 | 
				
			|||||||
@ -203,6 +203,10 @@ class autocast:
 | 
				
			|||||||
        enabled: bool = True,
 | 
					        enabled: bool = True,
 | 
				
			||||||
        cache_enabled: Optional[bool] = None,
 | 
					        cache_enabled: Optional[bool] = None,
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
 | 
					        if not isinstance(device_type, str):
 | 
				
			||||||
 | 
					            raise ValueError(
 | 
				
			||||||
 | 
					                f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
        if torch._jit_internal.is_scripting():
 | 
					        if torch._jit_internal.is_scripting():
 | 
				
			||||||
            self._enabled = enabled
 | 
					            self._enabled = enabled
 | 
				
			||||||
            self.device = device_type
 | 
					            self.device = device_type
 | 
				
			||||||
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user