[Pytorch][ATEN] Enable FP8 concatenate (#138046)

Summary: Float8 is becoming an increasingly popular datatype now that it is well supported on GPUs. This diff enables FP8 to work with `torch.cat`. The change is straightforward, since memory operations don't vary based on the input dtype, but it can be quite helpful for FP8-based models.
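
A minimal sketch of what this enables from Python (assuming a build that carries this change; `torch.float8_e4m3fn` is used here as a representative FP8 dtype):

```python
import torch

# FP8 tensors are typically created by downcasting from a higher-precision dtype.
a = torch.randn(2, 4).to(torch.float8_e4m3fn)
b = torch.randn(3, 4).to(torch.float8_e4m3fn)

# With this change the cat kernels dispatch on the float8 dtypes as well,
# so the concatenation stays in FP8 end to end.
out = torch.cat([a, b], dim=0)
print(out.shape, out.dtype)  # torch.Size([5, 4]) torch.float8_e4m3fn

if torch.cuda.is_available():
    out_cuda = torch.cat([a.cuda(), b.cuda()], dim=0)
    print(out_cuda.dtype)  # torch.float8_e4m3fn
```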

Test Plan:
```
buck2 run mode/opt -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010 -c fbcode.nvcc_arch=h100a -c fbcode.platform010_cuda_version=12 //caffe2/test:tensor_creation -- -r test_cat_all_dtypes_and_devices
```
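
For reference, a rough standalone check in the same spirit as `test_cat_all_dtypes_and_devices` (a sketch, not the actual test or the fbcode command above; it assumes the FP8 dtypes are constructible on the target device and compares bit patterns via `uint8` views so it does not rely on float8 comparison ops):

```python
import torch

FP8_DTYPES = (
    torch.float8_e4m3fn,
    torch.float8_e4m3fnuz,
    torch.float8_e5m2,
    torch.float8_e5m2fnuz,
)

def check_fp8_cat(device):
    for dt in FP8_DTYPES:
        x = torch.tensor([[1, 2], [3, 4]], dtype=dt, device=device)
        expected = torch.tensor([[1, 2], [3, 4], [1, 2], [3, 4]], dtype=dt, device=device)
        out = torch.cat((x, x), dim=0)
        # Compare raw bit patterns so the check works even where float8 eq is unsupported.
        assert torch.equal(out.view(torch.uint8), expected.view(torch.uint8))

check_fp8_cat("cpu")
if torch.cuda.is_available():
    check_fp8_cat("cuda")
```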

Differential Revision: D64443965

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138046
Approved by: https://github.com/eqy, https://github.com/qchip, https://github.com/jianyuh
Author: Josh Fromm
Committed by: PyTorch MergeBot
Date: 2024-10-17 04:58:54 +00:00
Parent: ebd60f4074
Commit: 9c084cccfd
3 changed files with 101 additions and 29 deletions

aten/src/ATen/native/cpu/CatKernel.cpp

@@ -2,9 +2,10 @@
 #include <ATen/core/Tensor.h>
-#include <ATen/Dispatch.h>
-#include <ATen/native/cpu/CatKernel.h>
+#include <ATen/Dispatch_v2.h>
 #include <ATen/cpu/vec/functional.h>
 #include <ATen/cpu/vec/vec.h>
+#include <ATen/native/cpu/CatKernel.h>
 #include <c10/util/irange.h>
 namespace at::native {
@@ -16,15 +17,19 @@ struct InputMeta {
 int64_t inner_size;
 InputMeta(const Tensor& t, int64_t dim, int64_t inner)
-: data_ptr(t.const_data_ptr())
-, inner_size(t.sizes()[dim] * inner) {}
+: data_ptr(t.const_data_ptr()), inner_size(t.sizes()[dim] * inner) {}
 };
 template <typename scalar_t>
-void cat_serial_kernel_impl(const Tensor& result, const MaterializedITensorListRef& tensors, int64_t dim) {
+void cat_serial_kernel_impl(
+const Tensor& result,
+const MaterializedITensorListRef& tensors,
+int64_t dim) {
 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-dim >= 0 && dim < result.dim(), "dim out of range in cat_serial_kernel_impl");
-int64_t outer = result.numel() / (result.sizes()[dim] * result.strides()[dim]);
+dim >= 0 && dim < result.dim(),
+"dim out of range in cat_serial_kernel_impl");
+int64_t outer =
+result.numel() / (result.sizes()[dim] * result.strides()[dim]);
 scalar_t* result_data = result.data_ptr<scalar_t>();
 int64_t ninputs = static_cast<int64_t>(tensors.size());
 std::vector<InputMeta> inputs;
@@ -38,15 +43,16 @@ void cat_serial_kernel_impl(const Tensor& result, const MaterializedITensorListR
 for (const auto i : c10::irange(outer)) {
 for (const auto j : c10::irange(ninputs)) {
 int64_t local_inner = inputs[j].inner_size;
-const scalar_t* input_ptr = (const scalar_t*)(inputs[j].data_ptr) + i * local_inner;
+const scalar_t* input_ptr =
+(const scalar_t*)(inputs[j].data_ptr) + i * local_inner;
 int64_t d = 0;
 for (; d < local_inner - (local_inner % Vec::size()); d += Vec::size()) {
 Vec in_vec = Vec::loadu(input_ptr + d);
 in_vec.store(result_ptr + d);
 }
-#if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
-# pragma unroll
-#endif
+#if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
+#pragma unroll
+#endif
 for (; d < local_inner; d++) {
 result_ptr[d] = input_ptr[d];
 }
@@ -55,14 +61,23 @@ void cat_serial_kernel_impl(const Tensor& result, const MaterializedITensorListR
 }
 }
-void cat_serial_kernel(const Tensor& result, const MaterializedITensorListRef& tensors, int64_t dim) {
-AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, result.scalar_type(), "cat_serial_kernel", [&]() {
-cat_serial_kernel_impl<scalar_t>(result, tensors, dim);
-});
+void cat_serial_kernel(
+const Tensor& result,
+const MaterializedITensorListRef& tensors,
+int64_t dim) {
+AT_DISPATCH_V2(
+result.scalar_type(),
+"cat_serial_kernel",
+AT_WRAP(
+[&]() { cat_serial_kernel_impl<scalar_t>(result, tensors, dim); }),
+AT_EXPAND(AT_FLOATING_TYPES),
+kBFloat16,
+kHalf,
+AT_EXPAND(AT_FLOAT8_TYPES));
 }
 } // anonymous namespace
 REGISTER_DISPATCH(cat_serial_stub, &cat_serial_kernel);
-} // at::native
+} // namespace at::native

aten/src/ATen/native/cuda/Shape.cu

@@ -500,10 +500,21 @@ TORCH_IMPL_FUNC(cat_out_cuda)
 parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE, 1>(result, materialized, dim, nDims, memory_format);
 });
 } else {
-AT_DISPATCH_V2(result.scalar_type(), "cat_cuda", AT_WRAP([&]() {
-using dtype = OpaqueType<sizeof(scalar_t)>;
-parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE, 1>(result, materialized, dim, nDims, memory_format);
-}), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
+AT_DISPATCH_V2(
+result.scalar_type(),
+"cat_cuda",
+AT_WRAP([&]() {
+using dtype = OpaqueType<sizeof(scalar_t)>;
+parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE, 1>(
+result, materialized, dim, nDims, memory_format);
+}),
+AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+kComplexHalf,
+kHalf,
+kBool,
+kBFloat16,
+AT_EXPAND(AT_FLOAT8_TYPES),
+AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
 }
 } else if (materialized.size() > 1 &&
 result.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
@@ -518,10 +529,27 @@ TORCH_IMPL_FUNC(cat_out_cuda)
 parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE/2, CAT_ARRAY_BATCH_SIZE/2>(result, materialized, dim, nDims, memory_format);
 });
 } else {
-AT_DISPATCH_V2(result.scalar_type(), "cat_cuda", AT_WRAP([&]() {
-using dtype = OpaqueType<sizeof(scalar_t)>;
-parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE/2, CAT_ARRAY_BATCH_SIZE/2>(result, materialized, dim, nDims, memory_format);
-}), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
+AT_DISPATCH_V2(
+result.scalar_type(),
+"cat_cuda",
+AT_WRAP([&]() {
+using dtype = OpaqueType<sizeof(scalar_t)>;
+parallel_cat<
+dtype,
+CAT_ARRAY_BATCH_SIZE / 2,
+CAT_ARRAY_BATCH_SIZE / 2>(
+result, materialized, dim, nDims, memory_format);
+}),
+AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+kComplexHalf,
+kHalf,
+kBool,
+kBFloat16,
+kFloat8_e4m3fn,
+kFloat8_e4m3fnuz,
+kFloat8_e5m2,
+kFloat8_e5m2fnuz,
+AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
 }
 } else {
 int64_t offset = 0;

test/test_tensor_creation_ops.py

@@ -14,11 +14,27 @@ from typing import Any, Dict, List, Tuple
 from torch.testing import make_tensor
 from torch.testing._internal.common_utils import (
-TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings,
-torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict, slowTest,
-set_default_dtype, set_default_tensor_type,
-TEST_SCIPY, IS_MACOS, IS_PPC, IS_JETSON, IS_WINDOWS, parametrize, skipIfTorchDynamo,
-xfailIfTorchDynamo)
+TestCase,
+run_tests,
+do_test_empty_full,
+TEST_WITH_ROCM,
+suppress_warnings,
+torch_to_numpy_dtype_dict,
+numpy_to_torch_dtype_dict,
+slowTest,
+set_default_dtype,
+set_default_tensor_type,
+TEST_SCIPY,
+IS_MACOS,
+IS_PPC,
+IS_JETSON,
+IS_WINDOWS,
+IS_FBCODE,
+IS_SANDCASTLE,
+parametrize,
+skipIfTorchDynamo,
+xfailIfTorchDynamo,
+)
 from torch.testing._internal.common_device_type import (
 expectedFailureMeta, instantiate_device_type_tests, deviceCountAtLeast, onlyNativeDeviceTypes,
 onlyCPU, largeTensorTest, precisionOverride, dtypes,
@@ -148,7 +164,16 @@ class TestTensorCreation(TestCase):
 exact_dtype=False)
 def test_cat_all_dtypes_and_devices(self, device):
-for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16, torch.chalf):
+for dt in all_types_and_complex_and(
+torch.half,
+torch.bool,
+torch.bfloat16,
+torch.chalf,
+torch.float8_e4m3fn,
+torch.float8_e4m3fnuz,
+torch.float8_e5m2,
+torch.float8_e5m2fnuz,
+):
 x = torch.tensor([[1, 2], [3, 4]], dtype=dt, device=device)
 expected1 = torch.tensor([[1, 2], [3, 4], [1, 2], [3, 4]], dtype=dt, device=device)
@@ -1046,6 +1071,9 @@ class TestTensorCreation(TestCase):
 # Note: numpy -2.0 or -1.5 -> uint8 conversion is undefined
 # see https://github.com/pytorch/pytorch/issues/97794
 refs = (0, 254, 255, 0, 0, 0, 1, 2)
+elif dtype == torch.int16:
+# CPU min and max float -> int16 conversion is divergent.
+vals = (-2, -1.5, -.5, 0, .5, 1.5, 2)
 self._float_to_int_conversion_helper(vals, device, dtype, refs)
@@ -3556,6 +3584,7 @@ class TestRandomTensorCreation(TestCase):
 # Test exceptions when device and generator types are incompatible
 @onlyCUDA
+@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Produces inconsistent errors when run in fbcode.")
 def test_randperm_device_compatibility(self, device):
 cuda_gen = torch.Generator(device='cuda')
 cpu_gen = torch.Generator(device='cpu')