Add meta registration for _foreach_norm (2nd try) (#119927)

The first try reused TensorListMetadata, which caused illegal memory access when there were too many tensors in the list. This version instead launches multiple kernels over a simpler, address-only struct; since that struct can carry more tensor pointers per launch, the number of kernels stays small.
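
In sketch form, the cleanup launch now follows the pattern below. The names come from the diff in this commit (MAX_TENSORS_PER_KERNEL, TensorListAddresses, lpnorm_cleanup, vec_res, output_per_tensor, etc.); the start variable, the std::min chunk-size computation, and the generic norm_type placeholder are simplifications for illustration, not the literal code:

    const size_t num_kernels = ceil_div(ntensors, MAX_TENSORS_PER_KERNEL);
    for (size_t i = 0; i < num_kernels; i++) {
      // Tensors covered by this launch: a full chunk of 400 except possibly the last.
      const size_t start = i * MAX_TENSORS_PER_KERNEL;
      const size_t num_tensors_this_kernel =
          std::min(MAX_TENSORS_PER_KERNEL, ntensors - start);

      // Only output addresses ride along, so up to 400 pointers (~3.2KB) fit
      // under the 4KB kernel-argument limit.
      TensorListAddresses addr_struct;
      for (size_t j = 0; j < num_tensors_this_kernel; j++) {
        addr_struct.addresses[j] = vec_res[start + j].mutable_data_ptr<scalar_t>();
      }

      // One block per tensor in the chunk; each block reduces that tensor's
      // partial results and writes the final norm through its address.
      lpnorm_cleanup<scalar_t, norm_type>
          <<<num_tensors_this_kernel, 512, 0, stream>>>(
              output_per_tensor.const_data_ptr<opmath_t>() +
                  start * max_chunks_per_tensor,
              addr_struct,
              max_chunks_per_tensor);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }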

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119927
Approved by: https://github.com/albanD
This commit is contained in:
Jane Xu
2024-02-15 12:47:20 -08:00
committed by PyTorch MergeBot
parent 707cde9b31
commit 4319735ace
5 changed files with 120 additions and 34 deletions

View File

@@ -2,6 +2,7 @@
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/ceil_div.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/native/cuda/ForeachFunctors.cuh>
@@ -23,6 +24,23 @@ namespace at::native {
// _foreach_norm supports only L1, L2, and inf norm
enum class NormType { L1, L2, LInf };
// NOTE: This is a simple variant of TensorListMetadata in MultiTensorApply.cuh
// as we only need to track addresses for the lpnorm_cleanup function below.
// Why is this struct necessary? For the same reason the TensorListMetadata
// struct is necessary--which is to ferry static metadata to the CUDA kernel
// while complying with the 4kb size constraint. Since we only need to track
// addresses, we introduce this struct to be able to fit more Tensor pointers at
// a time, currently 400 empirically, compared to the much smaller values in
// depth_to_max_tensors. This way, we can launch fewer kernels for better
// performance.
//
// IF YOU USE THIS STRUCT, PLEASE ADD A ONE-OFF TEST IN test_foreach.py AS THIS
// IS CURRENTLY ONLY TESTED FOR _foreach_norm.
const size_t MAX_TENSORS_PER_KERNEL = 400;
struct TensorListAddresses {
const void* addresses[MAX_TENSORS_PER_KERNEL];
};
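// [Illustrative aside, not part of this diff] With 8-byte device pointers,
// sizeof(TensorListAddresses) is 400 * 8 = 3200 bytes, which, together with the
// kernel's other arguments (an output pointer and an int), stays under the 4KB
// kernel-argument constraint mentioned above. The limit could be made explicit
// with a hypothetical compile-time check along these lines:
//   static_assert(sizeof(TensorListAddresses) <= 4096,
//                 "TensorListAddresses must fit in the CUDA kernel parameter space");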
template <
typename T,
NormType norm_type,
@@ -112,7 +130,7 @@ template <
typename opmath_t = at::opmath_type<T>>
__global__ void lpnorm_cleanup(
const opmath_t* output_per_tensor,
T* ret_per_tensor,
TensorListAddresses addr_struct,
int max_chunks_per_tensor) {
__shared__ opmath_t vals[512];
@@ -130,7 +148,7 @@ __global__ void lpnorm_cleanup(
? at::native::cuda_utils::BlockReduceSum<opmath_t>(val, vals)
: at::native::cuda_utils::BlockReduceMax(val, vals);
if (threadIdx.x == 0) {
ret_per_tensor[blockIdx.x] =
*(T*)addr_struct.addresses[blockIdx.x] =
norm_type == NormType::L1 || norm_type == NormType::LInf
? final_val
: ::sqrt(final_val);
@@ -166,7 +184,7 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
return foreach_tensor_norm_slow(tensors, ord);
}
const int ntensors = tensors.size();
const size_t ntensors = tensors.size();
int max_chunks_per_tensor = -1;
for (int t = 0; t < ntensors; t++) {
@@ -178,9 +196,14 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
}
const auto options = tensors[0].options();
auto output_per_tensor = at::zeros(
{ntensors * max_chunks_per_tensor},
{static_cast<int64_t>(ntensors) * max_chunks_per_tensor},
options.dtype(toOpMathType(tensors[0].scalar_type())));
auto ret_per_tensor = at::empty({ntensors}, options);
std::vector<at::Tensor> vec_res;
vec_res.reserve(ntensors);
for (int i = 0; i < ntensors; i++) {
vec_res.push_back(at::empty({}, options));
}
auto tensor_lists = std::vector<std::vector<Tensor>>{tensors.vec()};
if (p == static_cast<double>(1)) {
@@ -200,11 +223,28 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
const at::cuda::OptionalCUDAGuard device_guard(
device_of(output_per_tensor));
auto stream = at::cuda::getCurrentCUDAStream();
lpnorm_cleanup<scalar_t, NormType::L1><<<ntensors, 512, 0, stream>>>(
output_per_tensor.const_data_ptr<opmath_t>(),
ret_per_tensor.mutable_data_ptr<scalar_t>(),
const size_t num_kernels = ceil_div(ntensors, MAX_TENSORS_PER_KERNEL);
for (auto i = 0; i < num_kernels; i++) {
const size_t num_tensors_this_kernel =
(i < num_kernels - 1 || ntensors % MAX_TENSORS_PER_KERNEL == 0)
? MAX_TENSORS_PER_KERNEL
: (ntensors % MAX_TENSORS_PER_KERNEL);
TensorListAddresses addr_struct;
for (auto j = 0; j < num_tensors_this_kernel; j++) {
addr_struct.addresses[j] = vec_res[i * MAX_TENSORS_PER_KERNEL + j]
.mutable_data_ptr<scalar_t>();
}
lpnorm_cleanup<scalar_t, NormType::L1>
<<<num_tensors_this_kernel, 512, 0, stream>>>(
output_per_tensor.const_data_ptr<opmath_t>() +
i * MAX_TENSORS_PER_KERNEL * max_chunks_per_tensor,
addr_struct,
max_chunks_per_tensor);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
});
} else if (p == static_cast<double>(2)) {
AT_DISPATCH_FLOATING_TYPES_AND2(
@@ -223,11 +263,28 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
const at::cuda::OptionalCUDAGuard device_guard(
device_of(output_per_tensor));
auto stream = at::cuda::getCurrentCUDAStream();
lpnorm_cleanup<scalar_t, NormType::L2><<<ntensors, 512, 0, stream>>>(
output_per_tensor.const_data_ptr<opmath_t>(),
ret_per_tensor.mutable_data_ptr<scalar_t>(),
const size_t num_kernels = ceil_div(ntensors, MAX_TENSORS_PER_KERNEL);
for (auto i = 0; i < num_kernels; i++) {
const size_t num_tensors_this_kernel =
(i < num_kernels - 1 || ntensors % MAX_TENSORS_PER_KERNEL == 0)
? MAX_TENSORS_PER_KERNEL
: (ntensors % MAX_TENSORS_PER_KERNEL);
TensorListAddresses addr_struct;
for (auto j = 0; j < num_tensors_this_kernel; j++) {
addr_struct.addresses[j] = vec_res[i * MAX_TENSORS_PER_KERNEL + j]
.mutable_data_ptr<scalar_t>();
}
lpnorm_cleanup<scalar_t, NormType::L2>
<<<num_tensors_this_kernel, 512, 0, stream>>>(
output_per_tensor.const_data_ptr<opmath_t>() +
i * MAX_TENSORS_PER_KERNEL * max_chunks_per_tensor,
addr_struct,
max_chunks_per_tensor);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
});
} else if (p == std::numeric_limits<double>::infinity()) {
AT_DISPATCH_FLOATING_TYPES_AND2(
@@ -246,12 +303,28 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
const at::cuda::OptionalCUDAGuard device_guard(
device_of(output_per_tensor));
auto stream = at::cuda::getCurrentCUDAStream();
const size_t num_kernels = ceil_div(ntensors, MAX_TENSORS_PER_KERNEL);
for (auto i = 0; i < num_kernels; i++) {
const size_t num_tensors_this_kernel =
(i < num_kernels - 1 || ntensors % MAX_TENSORS_PER_KERNEL == 0)
? MAX_TENSORS_PER_KERNEL
: (ntensors % MAX_TENSORS_PER_KERNEL);
TensorListAddresses addr_struct;
for (auto j = 0; j < num_tensors_this_kernel; j++) {
addr_struct.addresses[j] = vec_res[i * MAX_TENSORS_PER_KERNEL + j]
.mutable_data_ptr<scalar_t>();
}
lpnorm_cleanup<scalar_t, NormType::LInf>
<<<ntensors, 512, 0, stream>>>(
output_per_tensor.const_data_ptr<opmath_t>(),
ret_per_tensor.mutable_data_ptr<scalar_t>(),
<<<num_tensors_this_kernel, 512, 0, stream>>>(
output_per_tensor.const_data_ptr<opmath_t>() +
i * MAX_TENSORS_PER_KERNEL * max_chunks_per_tensor,
addr_struct,
max_chunks_per_tensor);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
});
} else {
TORCH_CHECK(
@@ -267,7 +340,7 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
int i = 0;
for (const auto& t : tensors) {
if (t.numel() != 0) {
result.emplace_back(ret_per_tensor[i]);
result.emplace_back(vec_res[i]);
i++;
} else {
result.emplace_back(at::zeros({}, options));

View File

@@ -646,6 +646,20 @@ class TestForeach(TestCase):
for i, e in enumerate(expect)))
self.assertEqual(expect, actual, equal_nan=False)
@onlyCUDA
@ops(foreach_reduce_op_db, allowed_dtypes=floating_types())
def test_big_num_tensors(self, device, dtype, op):
N = 600
tensorlist = [make_tensor((2, 3), dtype=dtype, device=device, noncontiguous=False) for _ in range(N)]
fn, ref_fn, *_ = self._get_funcs(op)
import math
for ord in (1, 2, math.inf):
actual = fn(inputs=[tensorlist], is_cuda=True, expect_fastpath=True, ord=ord, zero_size=False)
expect = ref_fn(inputs=[tensorlist], ord=ord)
self.assertEqual(expect, actual, equal_nan=True)
@onlyCUDA
@ops(foreach_reduce_op_db)
def test_foreach_reduce_large_input(self, device, dtype, op):

View File

@@ -401,22 +401,24 @@ def assert_ref_meta_equal(test_case, func, meta_rs, rs, msg_callable):
if not isinstance(r, torch.Tensor):
continue
test_assert(isinstance(meta_r, torch.Tensor), f"but real {i}th result is Tensor")
test_assert(meta_r.dtype == r.dtype, f"but real dtype was {r.dtype}")
test_assert(meta_r.shape == r.shape, f"but real shape was {r.shape}")
test_assert(meta_r.dtype == r.dtype, f"for element {i}, was {meta_r.dtype} but real dtype was {r.dtype}")
test_assert(meta_r.shape == r.shape, f"for element {i}, was {meta_r.shape} but real shape was {r.shape}")
# See https://github.com/pytorch/pytorch/issues/78050
if should_check_strides(func) == CheckStrides.ALL:
same_strides, _ = torch._prims_common.check_all_strides(meta_r, r)
test_assert(same_strides, f"but real stride was {r.stride()}")
test_assert(same_strides, f"for element {i}, was {meta_r.stride()} but real stride was {r.stride()}")
elif should_check_strides(func) == CheckStrides.SIGNIFICANT:
same_strides, _ = torch._prims_common.check_significant_strides(meta_r, r)
test_assert(same_strides, f"but real stride was {r.stride()}")
test_assert(same_strides, f"for element {i}, was {meta_r.stride()} but real stride was {r.stride()}")
test_assert(
meta_r.storage_offset() == r.storage_offset(),
f"but real storage_offset was {r.storage_offset()}")
test_assert(meta_r.requires_grad == r.requires_grad, f"but real requires_grad was {r.requires_grad}")
f"for element {i}, was {meta_r.storage_offset()} but real storage_offset was {r.storage_offset()}")
test_assert(meta_r.requires_grad == r.requires_grad,
f"for element {i}, was {meta_r.requires_grad} but real requires_grad was {r.requires_grad}")
if func not in CHECK_CONJ_SKIPS:
test_assert(meta_r.is_conj() == r.is_conj(), f"but real is_conj was {r.is_conj()}")
test_assert(meta_r.is_neg() == r.is_neg(), f"but real is_neg was {r.is_neg()}")
test_assert(meta_r.is_conj() == r.is_conj(),
f"for element {i}, was {meta_r.is_conj()} but real is_conj was {r.is_conj()}")
test_assert(meta_r.is_neg() == r.is_neg(), f"for element {i}, was {meta_r.is_neg()} but real is_neg was {r.is_neg()}")
# This environment variable controls whether or not we print expected failure

View File

@@ -3073,6 +3073,7 @@ def register_meta_foreach(ops):
aten._foreach_log1p,
aten._foreach_log2,
aten._foreach_neg,
aten._foreach_norm,
aten._foreach_reciprocal,
aten._foreach_round,
aten._foreach_sigmoid,

View File

@@ -9693,11 +9693,7 @@ foreach_reduce_op_db: List[ForeachFuncInfo] = [
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
),
),
]