[Pytorch][ATEN] Enable FP8 concatenate (#138046)
Summary: Float8 is becoming an increasingly popular datatype now that it is well supported on GPUs. This diff enables FP8 to work with `torch.cat`. The change is straightforward, since memory operations don't vary based on the input dtype, but it can be quite helpful for FP8-based models.

Test Plan:
```
buck2 run mode/opt -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010 -c fbcode.nvcc_arch=h100a -c fbcode.platform010_cuda_version=12 //caffe2/test:tensor_creation -- -r test_cat_all_dtypes_and_devices
```

Differential Revision: D64443965

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138046
Approved by: https://github.com/eqy, https://github.com/qchip, https://github.com/jianyuh
Committed by: PyTorch MergeBot
Parent: ebd60f4074
Commit: 9c084cccfd
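As a quick illustration of what the change enables, here is a minimal sketch (not from the PR; it assumes a PyTorch build containing this commit and a CUDA device, since FP8 is primarily a GPU dtype):

```python
import torch

# torch.cat on float8 inputs, previously unsupported, now works like any other
# dtype: concatenation only copies bytes, so no FP8 arithmetic is involved.
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float8_e4m3fn, device="cuda")
y = torch.cat([x, x], dim=0)
print(y.shape, y.dtype)  # torch.Size([4, 2]) torch.float8_e4m3fn
```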
```diff
@@ -14,11 +14,27 @@ from typing import Any, Dict, List, Tuple

 from torch.testing import make_tensor
 from torch.testing._internal.common_utils import (
-    TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings,
-    torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict, slowTest,
-    set_default_dtype, set_default_tensor_type,
-    TEST_SCIPY, IS_MACOS, IS_PPC, IS_JETSON, IS_WINDOWS, parametrize, skipIfTorchDynamo,
-    xfailIfTorchDynamo)
+    TestCase,
+    run_tests,
+    do_test_empty_full,
+    TEST_WITH_ROCM,
+    suppress_warnings,
+    torch_to_numpy_dtype_dict,
+    numpy_to_torch_dtype_dict,
+    slowTest,
+    set_default_dtype,
+    set_default_tensor_type,
+    TEST_SCIPY,
+    IS_MACOS,
+    IS_PPC,
+    IS_JETSON,
+    IS_WINDOWS,
+    IS_FBCODE,
+    IS_SANDCASTLE,
+    parametrize,
+    skipIfTorchDynamo,
+    xfailIfTorchDynamo,
+)
 from torch.testing._internal.common_device_type import (
     expectedFailureMeta, instantiate_device_type_tests, deviceCountAtLeast, onlyNativeDeviceTypes,
     onlyCPU, largeTensorTest, precisionOverride, dtypes,
```
```diff
@@ -148,7 +164,16 @@ class TestTensorCreation(TestCase):
                          exact_dtype=False)

     def test_cat_all_dtypes_and_devices(self, device):
-        for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16, torch.chalf):
+        for dt in all_types_and_complex_and(
+            torch.half,
+            torch.bool,
+            torch.bfloat16,
+            torch.chalf,
+            torch.float8_e4m3fn,
+            torch.float8_e4m3fnuz,
+            torch.float8_e5m2,
+            torch.float8_e5m2fnuz,
+        ):
             x = torch.tensor([[1, 2], [3, 4]], dtype=dt, device=device)

             expected1 = torch.tensor([[1, 2], [3, 4], [1, 2], [3, 4]], dtype=dt, device=device)
```
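The summary's point that memory operations don't depend on the input dtype can be checked directly. A hedged sketch (not part of the test suite; assumes a build with this change) comparing float8 concatenation against the same operation on the raw bytes:

```python
import torch

a = torch.randn(4, 8).to(torch.float8_e4m3fn)
b = torch.randn(4, 8).to(torch.float8_e4m3fn)

# cat on the float8 tensors themselves...
out = torch.cat([a, b], dim=0)

# ...should be bitwise identical to cat on the storage reinterpreted as uint8,
# since the kernel only moves memory and never interprets the FP8 values.
ref = torch.cat([a.view(torch.uint8), b.view(torch.uint8)], dim=0)
assert torch.equal(out.view(torch.uint8), ref)
```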
```diff
@@ -1046,6 +1071,9 @@ class TestTensorCreation(TestCase):
             # Note: numpy -2.0 or -1.5 -> uint8 conversion is undefined
             # see https://github.com/pytorch/pytorch/issues/97794
             refs = (0, 254, 255, 0, 0, 0, 1, 2)
+        elif dtype == torch.int16:
+            # CPU min and max float -> int16 conversion is divergent.
+            vals = (-2, -1.5, -.5, 0, .5, 1.5, 2)

         self._float_to_int_conversion_helper(vals, device, dtype, refs)

```
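An aside on the int16 branch above: the in-diff comment hints at why `vals` is restricted there. A small illustrative sketch (an inference from that comment, not something stated by the PR):

```python
import torch

# Casting a float far outside int16's range is undefined behavior in C++,
# so CPU, CUDA, and NumPy may legitimately disagree on the result; the test
# therefore keeps `vals` in range for int16 rather than using min/max floats.
fmin = torch.finfo(torch.float).min
print(torch.tensor([fmin]).to(torch.int16))  # backend-dependent result
```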
```diff
@@ -3556,6 +3584,7 @@ class TestRandomTensorCreation(TestCase):

     # Test exceptions when device and generator types are incompatible
     @onlyCUDA
+    @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Produces inconsistent errors when run in fbcode.")
     def test_randperm_device_compatibility(self, device):
         cuda_gen = torch.Generator(device='cuda')
         cpu_gen = torch.Generator(device='cpu')
```
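For context on the hunk above, the kind of incompatibility `test_randperm_device_compatibility` exercises looks roughly like this (a sketch; the exact error type and message are an assumption and vary across builds, which is why the test is now skipped in fbcode):

```python
import torch

# Driving a CUDA randperm with a CPU generator (or vice versa) is expected
# to raise rather than silently ignore the mismatched generator.
cpu_gen = torch.Generator(device="cpu")
try:
    torch.randperm(10, generator=cpu_gen, device="cuda")
except RuntimeError as err:
    print(f"raised as expected: {err}")
```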