Ensure autocast device_type is a string + Unit test (#125014)

Reviving #124873 (already approved) to resolve CLA issues

Fixes #124738

(Marked as draft until I get local unit tests to run)

Edit: Tests passing

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125014
Approved by: https://github.com/mikaylagawarecki, https://github.com/soulitzer
This commit is contained in:
Alana Xiang
2024-04-28 16:27:27 +00:00
committed by PyTorch MergeBot
parent 1a0b247762
commit 6761b49551
2 changed files with 11 additions and 0 deletions

View File

@ -343,6 +343,13 @@ class TestTorchAutocast(TestCase):
with self.assertRaisesRegex(RuntimeError, msg):
assert torch.amp.is_autocast_available(device_type=dev)
def test_non_string_device(self):
"""Test that `autocast` throws a ValueError when provided a `torch.device` object for `device_type` instead of a string"""
dev = torch.device("cpu")
msg = f"Expected `device_type` of type `str`, got: `{type(dev)}`"
with self.assertRaisesRegex(expected_exception=ValueError, expected_regex=msg):
torch.autocast(device_type=dev)
if __name__ == "__main__":
run_tests()