mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Ensure autocast device_type is a string + Unit test (#125014)
Reviving #124873 (already approved) to resolve CLA issues Fixes #124738 (Marked as draft until I get local unit tests to run) Edit: Tests passing Pull Request resolved: https://github.com/pytorch/pytorch/pull/125014 Approved by: https://github.com/mikaylagawarecki, https://github.com/soulitzer
This commit is contained in:
committed by
PyTorch MergeBot
parent
1a0b247762
commit
6761b49551
@ -343,6 +343,13 @@ class TestTorchAutocast(TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, msg):
|
||||
assert torch.amp.is_autocast_available(device_type=dev)
|
||||
|
||||
def test_non_string_device(self):
|
||||
"""Test that `autocast` throws a ValueError when provided a `torch.device` object for `device_type` instead of a string"""
|
||||
dev = torch.device("cpu")
|
||||
msg = f"Expected `device_type` of type `str`, got: `{type(dev)}`"
|
||||
with self.assertRaisesRegex(expected_exception=ValueError, expected_regex=msg):
|
||||
torch.autocast(device_type=dev)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user