Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Pytorch][ATEN] Enable FP8 concatenate (#138046)
Summary: Float8 is becoming an increasingly popular datatype now that it is well supported on GPUs. This diff enables FP8 to work with `torch.cat`. The change is straightforward, since memory operations do not vary based on the input dtype, but it can be quite helpful for FP8-based models.

Test Plan:
```
buck2 run mode/opt -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010 -c fbcode.nvcc_arch=h100a -c fbcode.platform010_cuda_version=12 //caffe2/test:tensor_creation -- -r test_cat_all_dtypes_and_devices
```

Differential Revision: D64443965
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138046
Approved by: https://github.com/eqy, https://github.com/qchip, https://github.com/jianyuh
commit 9c084cccfd (parent ebd60f4074)
committed by: PyTorch MergeBot
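For context, a minimal usage sketch of what this change enables (hypothetical shapes; assumes a PyTorch build that includes this commit and a float8-capable environment):

```
import torch

# Two FP8 tensors; the same pattern applies on CUDA devices.
a = torch.randn(2, 4).to(torch.float8_e4m3fn)
b = torch.randn(3, 4).to(torch.float8_e4m3fn)

# Concatenation only moves bytes, so no float8 arithmetic is required;
# this commit simply adds the float8 dtypes to the cat dispatch macros.
out = torch.cat([a, b], dim=0)
print(out.shape, out.dtype)  # torch.Size([5, 4]) torch.float8_e4m3fn

# Upcast when values need to be inspected or compared numerically.
print(out.float())
```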
ATen CPU cat kernel (CatKernel.cpp):

@@ -2,9 +2,10 @@
 #include <ATen/core/Tensor.h>
-#include <ATen/Dispatch.h>
-#include <ATen/native/cpu/CatKernel.h>
+
+#include <ATen/Dispatch_v2.h>
 #include <ATen/cpu/vec/functional.h>
 #include <ATen/cpu/vec/vec.h>
+#include <ATen/native/cpu/CatKernel.h>
 #include <c10/util/irange.h>
 
 namespace at::native {
 
@@ -16,15 +17,19 @@ struct InputMeta {
   int64_t inner_size;
 
   InputMeta(const Tensor& t, int64_t dim, int64_t inner)
-    : data_ptr(t.const_data_ptr())
-    , inner_size(t.sizes()[dim] * inner) {}
+      : data_ptr(t.const_data_ptr()), inner_size(t.sizes()[dim] * inner) {}
 };
 
 template <typename scalar_t>
-void cat_serial_kernel_impl(const Tensor& result, const MaterializedITensorListRef& tensors, int64_t dim) {
+void cat_serial_kernel_impl(
+    const Tensor& result,
+    const MaterializedITensorListRef& tensors,
+    int64_t dim) {
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-      dim >= 0 && dim < result.dim(), "dim out of range in cat_serial_kernel_impl");
-  int64_t outer = result.numel() / (result.sizes()[dim] * result.strides()[dim]);
+      dim >= 0 && dim < result.dim(),
+      "dim out of range in cat_serial_kernel_impl");
+  int64_t outer =
+      result.numel() / (result.sizes()[dim] * result.strides()[dim]);
   scalar_t* result_data = result.data_ptr<scalar_t>();
   int64_t ninputs = static_cast<int64_t>(tensors.size());
   std::vector<InputMeta> inputs;
@@ -38,15 +43,16 @@ void cat_serial_kernel_impl(const Tensor& result, const MaterializedITensorListR
   for (const auto i : c10::irange(outer)) {
     for (const auto j : c10::irange(ninputs)) {
       int64_t local_inner = inputs[j].inner_size;
-      const scalar_t* input_ptr = (const scalar_t*)(inputs[j].data_ptr) + i * local_inner;
+      const scalar_t* input_ptr =
+          (const scalar_t*)(inputs[j].data_ptr) + i * local_inner;
       int64_t d = 0;
       for (; d < local_inner - (local_inner % Vec::size()); d += Vec::size()) {
         Vec in_vec = Vec::loadu(input_ptr + d);
         in_vec.store(result_ptr + d);
       }
-      #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
-      # pragma unroll
-      #endif
+#if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
+#pragma unroll
+#endif
       for (; d < local_inner; d++) {
         result_ptr[d] = input_ptr[d];
       }
@@ -55,14 +61,23 @@ void cat_serial_kernel_impl(const Tensor& result, const MaterializedITensorListR
   }
 }
 
-void cat_serial_kernel(const Tensor& result, const MaterializedITensorListRef& tensors, int64_t dim) {
-  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, result.scalar_type(), "cat_serial_kernel", [&]() {
-    cat_serial_kernel_impl<scalar_t>(result, tensors, dim);
-  });
+void cat_serial_kernel(
+    const Tensor& result,
+    const MaterializedITensorListRef& tensors,
+    int64_t dim) {
+  AT_DISPATCH_V2(
+      result.scalar_type(),
+      "cat_serial_kernel",
+      AT_WRAP(
+          [&]() { cat_serial_kernel_impl<scalar_t>(result, tensors, dim); }),
+      AT_EXPAND(AT_FLOATING_TYPES),
+      kBFloat16,
+      kHalf,
+      AT_EXPAND(AT_FLOAT8_TYPES));
 }
 
 } // anonymous namespace
 
 REGISTER_DISPATCH(cat_serial_stub, &cat_serial_kernel);
 
-} // at::native
+} // namespace at::native
ATen CUDA cat kernel (cat_out_cuda):

@@ -500,10 +500,21 @@ TORCH_IMPL_FUNC(cat_out_cuda)
         parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE, 1>(result, materialized, dim, nDims, memory_format);
       });
     } else {
-      AT_DISPATCH_V2(result.scalar_type(), "cat_cuda", AT_WRAP([&]() {
-        using dtype = OpaqueType<sizeof(scalar_t)>;
-        parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE, 1>(result, materialized, dim, nDims, memory_format);
-      }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
+      AT_DISPATCH_V2(
+          result.scalar_type(),
+          "cat_cuda",
+          AT_WRAP([&]() {
+            using dtype = OpaqueType<sizeof(scalar_t)>;
+            parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE, 1>(
+                result, materialized, dim, nDims, memory_format);
+          }),
+          AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+          kComplexHalf,
+          kHalf,
+          kBool,
+          kBFloat16,
+          AT_EXPAND(AT_FLOAT8_TYPES),
+          AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
     }
   } else if (materialized.size() > 1 &&
       result.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
@@ -518,10 +529,27 @@ TORCH_IMPL_FUNC(cat_out_cuda)
         parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE/2, CAT_ARRAY_BATCH_SIZE/2>(result, materialized, dim, nDims, memory_format);
       });
     } else {
-      AT_DISPATCH_V2(result.scalar_type(), "cat_cuda", AT_WRAP([&]() {
-        using dtype = OpaqueType<sizeof(scalar_t)>;
-        parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE/2, CAT_ARRAY_BATCH_SIZE/2>(result, materialized, dim, nDims, memory_format);
-      }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
+      AT_DISPATCH_V2(
+          result.scalar_type(),
+          "cat_cuda",
+          AT_WRAP([&]() {
+            using dtype = OpaqueType<sizeof(scalar_t)>;
+            parallel_cat<
+                dtype,
+                CAT_ARRAY_BATCH_SIZE / 2,
+                CAT_ARRAY_BATCH_SIZE / 2>(
+                result, materialized, dim, nDims, memory_format);
+          }),
+          AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+          kComplexHalf,
+          kHalf,
+          kBool,
+          kBFloat16,
+          kFloat8_e4m3fn,
+          kFloat8_e4m3fnuz,
+          kFloat8_e5m2,
+          kFloat8_e5m2fnuz,
+          AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
     }
   } else {
     int64_t offset = 0;
Tensor-creation tests (TestTensorCreation, TestRandomTensorCreation):

@@ -14,11 +14,27 @@ from typing import Any, Dict, List, Tuple
 
 from torch.testing import make_tensor
 from torch.testing._internal.common_utils import (
-    TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings,
-    torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict, slowTest,
-    set_default_dtype, set_default_tensor_type,
-    TEST_SCIPY, IS_MACOS, IS_PPC, IS_JETSON, IS_WINDOWS, parametrize, skipIfTorchDynamo,
-    xfailIfTorchDynamo)
+    TestCase,
+    run_tests,
+    do_test_empty_full,
+    TEST_WITH_ROCM,
+    suppress_warnings,
+    torch_to_numpy_dtype_dict,
+    numpy_to_torch_dtype_dict,
+    slowTest,
+    set_default_dtype,
+    set_default_tensor_type,
+    TEST_SCIPY,
+    IS_MACOS,
+    IS_PPC,
+    IS_JETSON,
+    IS_WINDOWS,
+    IS_FBCODE,
+    IS_SANDCASTLE,
+    parametrize,
+    skipIfTorchDynamo,
+    xfailIfTorchDynamo,
+)
 from torch.testing._internal.common_device_type import (
     expectedFailureMeta, instantiate_device_type_tests, deviceCountAtLeast, onlyNativeDeviceTypes,
     onlyCPU, largeTensorTest, precisionOverride, dtypes,
@@ -148,7 +164,16 @@ class TestTensorCreation(TestCase):
                              exact_dtype=False)
 
     def test_cat_all_dtypes_and_devices(self, device):
-        for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16, torch.chalf):
+        for dt in all_types_and_complex_and(
+            torch.half,
+            torch.bool,
+            torch.bfloat16,
+            torch.chalf,
+            torch.float8_e4m3fn,
+            torch.float8_e4m3fnuz,
+            torch.float8_e5m2,
+            torch.float8_e5m2fnuz,
+        ):
             x = torch.tensor([[1, 2], [3, 4]], dtype=dt, device=device)
 
             expected1 = torch.tensor([[1, 2], [3, 4], [1, 2], [3, 4]], dtype=dt, device=device)
@@ -1046,6 +1071,9 @@
                 # Note: numpy -2.0 or -1.5 -> uint8 conversion is undefined
                 # see https://github.com/pytorch/pytorch/issues/97794
                 refs = (0, 254, 255, 0, 0, 0, 1, 2)
+            elif dtype == torch.int16:
+                # CPU min and max float -> int16 conversion is divergent.
+                vals = (-2, -1.5, -.5, 0, .5, 1.5, 2)
 
             self._float_to_int_conversion_helper(vals, device, dtype, refs)
 
@@ -3556,6 +3584,7 @@ class TestRandomTensorCreation(TestCase):
 
     # Test exceptions when device and generator types are incompatible
     @onlyCUDA
+    @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Produces inconsistent errors when run in fbcode.")
    def test_randperm_device_compatibility(self, device):
         cuda_gen = torch.Generator(device='cuda')
         cpu_gen = torch.Generator(device='cpu')