additional support for float8_e4m3fnuz and _e5m2fnuz (#115214)

Follow up to #107586.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115214
Approved by: https://github.com/peterbell10, https://github.com/malfet
This commit is contained in:
Jeff Daily
2024-01-22 18:33:41 +00:00
committed by PyTorch MergeBot
parent 56ef5afdee
commit 01abb5af21
43 changed files with 708 additions and 625 deletions

View File

@ -207,6 +207,14 @@ class _StorageBase:
"""Casts this storage to float8_e4m3fn type"""
return self._to(torch.float8_e4m3fn)
def float8_e5m2fnuz(self):
"""Casts this storage to float8_e5m2fnuz type"""
return self._to(torch.float8_e5m2fnuz)
def float8_e4m3fnuz(self):
"""Casts this storage to float8_e4m3fnuz type"""
return self._to(torch.float8_e4m3fnuz)
def is_pinned(self, device: Union[str, torch.device] = 'cuda'):
r"""Determine whether the CPU storage is already pinned on device.
@ -1070,6 +1078,16 @@ class TypedStorage:
_warn_typed_storage_removal()
return self._to(torch.float8_e4m3fn)
def float8_e5m2fnuz(self):
    """Casts this storage to float8_e5m2fnuz type"""
    # Deprecation notice must fire before the conversion is attempted.
    _warn_typed_storage_removal()
    target_dtype = torch.float8_e5m2fnuz
    return self._to(target_dtype)
def float8_e4m3fnuz(self):
    """Casts this storage to float8_e4m3fnuz type"""
    # Deprecation notice must fire before the conversion is attempted.
    _warn_typed_storage_removal()
    target_dtype = torch.float8_e4m3fnuz
    return self._to(target_dtype)
@classmethod
def from_file(cls, filename, shared, size):
"""from_file(filename, shared=False, size=0) -> Storage