mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Ensure autocast device_type is a string + Unit test (#125014)
Reviving #124873 (already approved) to resolve CLA issues Fixes #124738 (Marked as draft until I get local unit tests to run) Edit: Tests passing Pull Request resolved: https://github.com/pytorch/pytorch/pull/125014 Approved by: https://github.com/mikaylagawarecki, https://github.com/soulitzer
This commit is contained in:
committed by
PyTorch MergeBot
parent
1a0b247762
commit
6761b49551
@ -343,6 +343,13 @@ class TestTorchAutocast(TestCase):
|
|||||||
with self.assertRaisesRegex(RuntimeError, msg):
|
with self.assertRaisesRegex(RuntimeError, msg):
|
||||||
assert torch.amp.is_autocast_available(device_type=dev)
|
assert torch.amp.is_autocast_available(device_type=dev)
|
||||||
|
|
||||||
|
def test_non_string_device(self):
|
||||||
|
"""Test that `autocast` throws a ValueError when provided a `torch.device` object for `device_type` instead of a string"""
|
||||||
|
dev = torch.device("cpu")
|
||||||
|
msg = f"Expected `device_type` of type `str`, got: `{type(dev)}`"
|
||||||
|
with self.assertRaisesRegex(expected_exception=ValueError, expected_regex=msg):
|
||||||
|
torch.autocast(device_type=dev)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
run_tests()
|
run_tests()
|
||||||
|
@ -203,6 +203,10 @@ class autocast:
|
|||||||
enabled: bool = True,
|
enabled: bool = True,
|
||||||
cache_enabled: Optional[bool] = None,
|
cache_enabled: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
|
if not isinstance(device_type, str):
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
|
||||||
|
)
|
||||||
if torch._jit_internal.is_scripting():
|
if torch._jit_internal.is_scripting():
|
||||||
self._enabled = enabled
|
self._enabled = enabled
|
||||||
self.device = device_type
|
self.device = device_type
|
||||||
|
Reference in New Issue
Block a user