[redo] Fp8 support for item() with cuda, index_select, and fill_ cpu (#137341)

Summary:

Redo of https://github.com/pytorch/pytorch/pull/128780, easier to copy-paste.

Test Plan: CI

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137341
Approved by: https://github.com/eqy
This commit is contained in:
vasiliy
2024-10-07 00:58:51 +00:00
committed by PyTorch MergeBot
parent d1b87e26e5
commit a063a82c8b
8 changed files with 52 additions and 18 deletions

View File

@@ -26,7 +26,7 @@ from torch.testing._internal.common_device_type import (
 from torch.testing._internal.common_dtype import (
     all_types_and_complex, all_types_and_complex_and, all_types_and, floating_and_complex_types, complex_types,
     floating_types, floating_and_complex_types_and, integral_types, integral_types_and, get_all_dtypes,
-    float_to_corresponding_complex_type_map
+    float_to_corresponding_complex_type_map, all_types_complex_float8_and
 )
 from torch.utils.dlpack import to_dlpack
@@ -158,11 +158,13 @@ class TestTensorCreation(TestCase):
         self.assertEqual(torch.cat((x, x), 1), expected2)

     def test_fill_all_dtypes_and_devices(self, device):
-        for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16, torch.chalf):
+        for dt in all_types_complex_float8_and(torch.half, torch.bool, torch.bfloat16, torch.chalf):
             for x in [torch.tensor((10, 10), dtype=dt, device=device),
                       torch.empty(10000, dtype=dt, device=device)]:  # large tensor
                 numel = x.numel()
-                bound = 100 if dt in (torch.uint8, torch.int8) else 2000
+                bound_dtypes = (torch.uint8, torch.int8, torch.float8_e4m3fn,
+                                torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)
+                bound = 100 if dt in bound_dtypes else 2000
                 for n in range(-bound, bound, bound // 10):
                     x.fill_(n)
                     self.assertEqual(x, torch.tensor([n] * numel, dtype=dt, device=device))