[Pytorch][ATEN] Enable FP8 concatenate (#138046)

Summary: Float8 is becoming and increasingly popular datatype now that it is well supported on GPUs. This  diff enables FP8 to work with `torch.cat`. This is pretty straight forward since memory operations dont vary based on the input dtype, but can be quite helpful for FP8 based models.

Test Plan:
```
buck2 run mode/opt -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010 -c fbcode.nvcc_arch=h100a -c fbcode.platform010_cuda_version=12 //caffe2/test:tensor_creation -- -r test_cat_all_dtypes_and_devices
```

Differential Revision: D64443965

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138046
Approved by: https://github.com/eqy, https://github.com/qchip, https://github.com/jianyuh
This commit is contained in:
Josh Fromm
2024-10-17 04:58:54 +00:00
committed by PyTorch MergeBot
parent ebd60f4074
commit 9c084cccfd
3 changed files with 101 additions and 29 deletions

View File

@ -14,11 +14,27 @@ from typing import Any, Dict, List, Tuple
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings,
torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict, slowTest,
set_default_dtype, set_default_tensor_type,
TEST_SCIPY, IS_MACOS, IS_PPC, IS_JETSON, IS_WINDOWS, parametrize, skipIfTorchDynamo,
xfailIfTorchDynamo)
TestCase,
run_tests,
do_test_empty_full,
TEST_WITH_ROCM,
suppress_warnings,
torch_to_numpy_dtype_dict,
numpy_to_torch_dtype_dict,
slowTest,
set_default_dtype,
set_default_tensor_type,
TEST_SCIPY,
IS_MACOS,
IS_PPC,
IS_JETSON,
IS_WINDOWS,
IS_FBCODE,
IS_SANDCASTLE,
parametrize,
skipIfTorchDynamo,
xfailIfTorchDynamo,
)
from torch.testing._internal.common_device_type import (
expectedFailureMeta, instantiate_device_type_tests, deviceCountAtLeast, onlyNativeDeviceTypes,
onlyCPU, largeTensorTest, precisionOverride, dtypes,
@ -148,7 +164,16 @@ class TestTensorCreation(TestCase):
exact_dtype=False)
def test_cat_all_dtypes_and_devices(self, device):
for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16, torch.chalf):
for dt in all_types_and_complex_and(
torch.half,
torch.bool,
torch.bfloat16,
torch.chalf,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2,
torch.float8_e5m2fnuz,
):
x = torch.tensor([[1, 2], [3, 4]], dtype=dt, device=device)
expected1 = torch.tensor([[1, 2], [3, 4], [1, 2], [3, 4]], dtype=dt, device=device)
@ -1046,6 +1071,9 @@ class TestTensorCreation(TestCase):
# Note: numpy -2.0 or -1.5 -> uint8 conversion is undefined
# see https://github.com/pytorch/pytorch/issues/97794
refs = (0, 254, 255, 0, 0, 0, 1, 2)
elif dtype == torch.int16:
# CPU min and max float -> int16 conversion is divergent.
vals = (-2, -1.5, -.5, 0, .5, 1.5, 2)
self._float_to_int_conversion_helper(vals, device, dtype, refs)
@ -3556,6 +3584,7 @@ class TestRandomTensorCreation(TestCase):
# Test exceptions when device and generator types are incompatible
@onlyCUDA
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Produces inconsistent errors when run in fbcode.")
def test_randperm_device_compatibility(self, device):
cuda_gen = torch.Generator(device='cuda')
cpu_gen = torch.Generator(device='cpu')