Add meta registration for _foreach_norm (2nd try) (#119927)
The first attempt reused TensorListMetadata, which caused illegal memory accesses when the list contained too many tensors. This version instead launches multiple kernels over a simpler, address-only struct, sized so that as few kernels as possible are launched.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119927
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: 707cde9b31
Commit: 4319735ace
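
The change below chunks the tensor list on the host: it computes how many kernel launches are needed and how many tensors each launch covers, so that no single launch carries more than MAX_TENSORS_PER_KERNEL result pointers. A minimal standalone sketch of just that arithmetic (names such as chunk_sizes and kMaxTensorsPerKernel are illustrative, not from the PR; the actual kernel launch and ATen types are elided):

    #include <cstddef>
    #include <vector>

    // Mirrors MAX_TENSORS_PER_KERNEL from the diff below.
    constexpr std::size_t kMaxTensorsPerKernel = 400;

    // Returns how many tensors each kernel launch handles,
    // e.g. 600 tensors -> {400, 200}, 800 tensors -> {400, 400}.
    std::vector<std::size_t> chunk_sizes(std::size_t ntensors) {
      const std::size_t num_kernels =
          (ntensors + kMaxTensorsPerKernel - 1) / kMaxTensorsPerKernel; // ceil_div
      std::vector<std::size_t> sizes;
      for (std::size_t i = 0; i < num_kernels; i++) {
        const bool is_last = (i == num_kernels - 1);
        sizes.push_back(
            (!is_last || ntensors % kMaxTensorsPerKernel == 0)
                ? kMaxTensorsPerKernel
                : ntensors % kMaxTensorsPerKernel);
      }
      return sizes;
    }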
@@ -2,6 +2,7 @@
 #include <ATen/AccumulateType.h>
 #include <ATen/Dispatch.h>
 #include <ATen/OpMathType.h>
+#include <ATen/ceil_div.h>
 #include <ATen/native/ForeachUtils.h>
 #include <ATen/cuda/DeviceUtils.cuh>
 #include <ATen/native/cuda/ForeachFunctors.cuh>
@@ -23,6 +24,23 @@ namespace at::native {
 // _foreach_norm supports only L1, L2, and inf norm
 enum class NormType { L1, L2, LInf };
 
+// NOTE: This is a simple variant of TensorListMetadata in MultiTensorApply.cuh
+// as we only need to track addresses for the lpnorm_cleanup function below.
+// Why is this struct necessary? For the same reason the TensorListMetadata
+// struct is necessary--which is to ferry static metadata to the CUDA kernel
+// while complying with the 4kb size constraint. Since we only need to track
+// addresses, we introduce this struct to be able to fit more Tensor pointers at
+// a time, currently 400 empirically, compared to the much smaller values in
+// depth_to_max_tensors. This way, we can launch fewer kernels for better
+// performance.
+//
+// IF YOU USE THIS STRUCT, PLEASE ADD A ONE-OFF TEST IN test_foreach.py AS THIS
+// IS CURRENTLY ONLY TESTED FOR _foreach_norm.
+const size_t MAX_TENSORS_PER_KERNEL = 400;
+struct TensorListAddresses {
+  const void* addresses[MAX_TENSORS_PER_KERNEL];
+};
+
 template <
     typename T,
     NormType norm_type,
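The NOTE above is about the CUDA kernel-argument budget: a struct passed to a kernel by value must fit in roughly 4 KiB, so an array of 400 eight-byte pointers (3200 bytes) leaves headroom for the other kernel arguments. A small sketch of that size check, assuming 8-byte device pointers; the 4096-byte figure and the struct name used here (TensorListAddressesSketch) are stand-ins taken from the comment, not part of the PR:

    #include <cstddef>

    constexpr std::size_t kKernelArgLimitBytes = 4096; // "4kb size constraint" from the note above
    constexpr std::size_t kMaxTensorsPerKernel = 400;

    struct TensorListAddressesSketch {
      const void* addresses[kMaxTensorsPerKernel]; // 400 * 8 bytes = 3200 bytes
    };

    static_assert(sizeof(TensorListAddressesSketch) < kKernelArgLimitBytes,
                  "address-only struct must fit in the kernel argument space");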
@@ -112,7 +130,7 @@ template <
     typename opmath_t = at::opmath_type<T>>
 __global__ void lpnorm_cleanup(
     const opmath_t* output_per_tensor,
-    T* ret_per_tensor,
+    TensorListAddresses addr_struct,
     int max_chunks_per_tensor) {
   __shared__ opmath_t vals[512];
 
@@ -130,7 +148,7 @@ __global__ void lpnorm_cleanup(
           ? at::native::cuda_utils::BlockReduceSum<opmath_t>(val, vals)
           : at::native::cuda_utils::BlockReduceMax(val, vals);
   if (threadIdx.x == 0) {
-    ret_per_tensor[blockIdx.x] =
+    *(T*)addr_struct.addresses[blockIdx.x] =
         norm_type == NormType::L1 || norm_type == NormType::LInf
         ? final_val
         : ::sqrt(final_val);
@@ -166,7 +184,7 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
     return foreach_tensor_norm_slow(tensors, ord);
   }
 
-  const int ntensors = tensors.size();
+  const size_t ntensors = tensors.size();
   int max_chunks_per_tensor = -1;
 
   for (int t = 0; t < ntensors; t++) {
@@ -178,9 +196,14 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
   }
   const auto options = tensors[0].options();
   auto output_per_tensor = at::zeros(
-      {ntensors * max_chunks_per_tensor},
+      {static_cast<int64_t>(ntensors) * max_chunks_per_tensor},
       options.dtype(toOpMathType(tensors[0].scalar_type())));
-  auto ret_per_tensor = at::empty({ntensors}, options);
+
+  std::vector<at::Tensor> vec_res;
+  vec_res.reserve(ntensors);
+  for (int i = 0; i < ntensors; i++) {
+    vec_res.push_back(at::empty({}, options));
+  }
 
   auto tensor_lists = std::vector<std::vector<Tensor>>{tensors.vec()};
   if (p == static_cast<double>(1)) {
@@ -200,11 +223,28 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
           const at::cuda::OptionalCUDAGuard device_guard(
               device_of(output_per_tensor));
           auto stream = at::cuda::getCurrentCUDAStream();
-          lpnorm_cleanup<scalar_t, NormType::L1><<<ntensors, 512, 0, stream>>>(
-              output_per_tensor.const_data_ptr<opmath_t>(),
-              ret_per_tensor.mutable_data_ptr<scalar_t>(),
+
+          const size_t num_kernels = ceil_div(ntensors, MAX_TENSORS_PER_KERNEL);
+          for (auto i = 0; i < num_kernels; i++) {
+            const size_t num_tensors_this_kernel =
+                (i < num_kernels - 1 || ntensors % MAX_TENSORS_PER_KERNEL == 0)
+                ? MAX_TENSORS_PER_KERNEL
+                : (ntensors % MAX_TENSORS_PER_KERNEL);
+
+            TensorListAddresses addr_struct;
+            for (auto j = 0; j < num_tensors_this_kernel; j++) {
+              addr_struct.addresses[j] = vec_res[i * MAX_TENSORS_PER_KERNEL + j]
+                                             .mutable_data_ptr<scalar_t>();
+            }
+
+            lpnorm_cleanup<scalar_t, NormType::L1>
+                <<<num_tensors_this_kernel, 512, 0, stream>>>(
+                    output_per_tensor.const_data_ptr<opmath_t>() +
+                        i * MAX_TENSORS_PER_KERNEL * max_chunks_per_tensor,
+                    addr_struct,
                     max_chunks_per_tensor);
             C10_CUDA_KERNEL_LAUNCH_CHECK();
+          }
         });
   } else if (p == static_cast<double>(2)) {
     AT_DISPATCH_FLOATING_TYPES_AND2(
@@ -223,11 +263,28 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
          const at::cuda::OptionalCUDAGuard device_guard(
               device_of(output_per_tensor));
           auto stream = at::cuda::getCurrentCUDAStream();
-          lpnorm_cleanup<scalar_t, NormType::L2><<<ntensors, 512, 0, stream>>>(
-              output_per_tensor.const_data_ptr<opmath_t>(),
-              ret_per_tensor.mutable_data_ptr<scalar_t>(),
+
+          const size_t num_kernels = ceil_div(ntensors, MAX_TENSORS_PER_KERNEL);
+          for (auto i = 0; i < num_kernels; i++) {
+            const size_t num_tensors_this_kernel =
+                (i < num_kernels - 1 || ntensors % MAX_TENSORS_PER_KERNEL == 0)
+                ? MAX_TENSORS_PER_KERNEL
+                : (ntensors % MAX_TENSORS_PER_KERNEL);
+
+            TensorListAddresses addr_struct;
+            for (auto j = 0; j < num_tensors_this_kernel; j++) {
+              addr_struct.addresses[j] = vec_res[i * MAX_TENSORS_PER_KERNEL + j]
+                                             .mutable_data_ptr<scalar_t>();
+            }
+
+            lpnorm_cleanup<scalar_t, NormType::L2>
+                <<<num_tensors_this_kernel, 512, 0, stream>>>(
+                    output_per_tensor.const_data_ptr<opmath_t>() +
+                        i * MAX_TENSORS_PER_KERNEL * max_chunks_per_tensor,
+                    addr_struct,
                     max_chunks_per_tensor);
             C10_CUDA_KERNEL_LAUNCH_CHECK();
+          }
         });
   } else if (p == std::numeric_limits<double>::infinity()) {
     AT_DISPATCH_FLOATING_TYPES_AND2(
@@ -246,12 +303,28 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
           const at::cuda::OptionalCUDAGuard device_guard(
               device_of(output_per_tensor));
           auto stream = at::cuda::getCurrentCUDAStream();
+
+          const size_t num_kernels = ceil_div(ntensors, MAX_TENSORS_PER_KERNEL);
+          for (auto i = 0; i < num_kernels; i++) {
+            const size_t num_tensors_this_kernel =
+                (i < num_kernels - 1 || ntensors % MAX_TENSORS_PER_KERNEL == 0)
+                ? MAX_TENSORS_PER_KERNEL
+                : (ntensors % MAX_TENSORS_PER_KERNEL);
+
+            TensorListAddresses addr_struct;
+            for (auto j = 0; j < num_tensors_this_kernel; j++) {
+              addr_struct.addresses[j] = vec_res[i * MAX_TENSORS_PER_KERNEL + j]
+                                             .mutable_data_ptr<scalar_t>();
+            }
+
           lpnorm_cleanup<scalar_t, NormType::LInf>
-              <<<ntensors, 512, 0, stream>>>(
-                  output_per_tensor.const_data_ptr<opmath_t>(),
-                  ret_per_tensor.mutable_data_ptr<scalar_t>(),
+                <<<num_tensors_this_kernel, 512, 0, stream>>>(
+                    output_per_tensor.const_data_ptr<opmath_t>() +
+                        i * MAX_TENSORS_PER_KERNEL * max_chunks_per_tensor,
+                    addr_struct,
                     max_chunks_per_tensor);
             C10_CUDA_KERNEL_LAUNCH_CHECK();
+          }
         });
   } else {
     TORCH_CHECK(
@@ -267,7 +340,7 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
   int i = 0;
   for (const auto& t : tensors) {
     if (t.numel() != 0) {
-      result.emplace_back(ret_per_tensor[i]);
+      result.emplace_back(vec_res[i]);
       i++;
     } else {
       result.emplace_back(at::zeros({}, options));
@@ -646,6 +646,20 @@ class TestForeach(TestCase):
                     for i, e in enumerate(expect)))
             self.assertEqual(expect, actual, equal_nan=False)
 
+    @onlyCUDA
+    @ops(foreach_reduce_op_db, allowed_dtypes=floating_types())
+    def test_big_num_tensors(self, device, dtype, op):
+        N = 600
+        tensorlist = [make_tensor((2, 3), dtype=dtype, device=device, noncontiguous=False) for _ in range(N)]
+        fn, ref_fn, *_ = self._get_funcs(op)
+
+        import math
+        for ord in (1, 2, math.inf):
+            actual = fn(inputs=[tensorlist], is_cuda=True, expect_fastpath=True, ord=ord, zero_size=False)
+            expect = ref_fn(inputs=[tensorlist], ord=ord)
+
+            self.assertEqual(expect, actual, equal_nan=True)
+
     @onlyCUDA
     @ops(foreach_reduce_op_db)
     def test_foreach_reduce_large_input(self, device, dtype, op):
@@ -401,22 +401,24 @@ def assert_ref_meta_equal(test_case, func, meta_rs, rs, msg_callable):
         if not isinstance(r, torch.Tensor):
             continue
         test_assert(isinstance(meta_r, torch.Tensor), f"but real {i}th result is Tensor")
-        test_assert(meta_r.dtype == r.dtype, f"but real dtype was {r.dtype}")
-        test_assert(meta_r.shape == r.shape, f"but real shape was {r.shape}")
+        test_assert(meta_r.dtype == r.dtype, f"for element {i}, was {meta_r.dtype} but real dtype was {r.dtype}")
+        test_assert(meta_r.shape == r.shape, f"for element {i}, was {meta_r.shape} but real shape was {r.shape}")
         # See https://github.com/pytorch/pytorch/issues/78050
         if should_check_strides(func) == CheckStrides.ALL:
             same_strides, _ = torch._prims_common.check_all_strides(meta_r, r)
-            test_assert(same_strides, f"but real stride was {r.stride()}")
+            test_assert(same_strides, f"for element {i}, was {meta_r.stride()} but real stride was {r.stride()}")
         elif should_check_strides(func) == CheckStrides.SIGNIFICANT:
             same_strides, _ = torch._prims_common.check_significant_strides(meta_r, r)
-            test_assert(same_strides, f"but real stride was {r.stride()}")
+            test_assert(same_strides, f"for element {i}, was {meta_r.stride()} but real stride was {r.stride()}")
         test_assert(
             meta_r.storage_offset() == r.storage_offset(),
-            f"but real storage_offset was {r.storage_offset()}")
-        test_assert(meta_r.requires_grad == r.requires_grad, f"but real requires_grad was {r.requires_grad}")
+            f"for element {i}, was {meta_r.storage_offset()} but real storage_offset was {r.storage_offset()}")
+        test_assert(meta_r.requires_grad == r.requires_grad,
+                    f"for element {i}, was {meta_r.requires_grad} but real requires_grad was {r.requires_grad}")
         if func not in CHECK_CONJ_SKIPS:
-            test_assert(meta_r.is_conj() == r.is_conj(), f"but real is_conj was {r.is_conj()}")
-        test_assert(meta_r.is_neg() == r.is_neg(), f"but real is_neg was {r.is_neg()}")
+            test_assert(meta_r.is_conj() == r.is_conj(),
+                        f"for element {i}, was {meta_r.is_conj()} but real is_conj was {r.is_conj()}")
+        test_assert(meta_r.is_neg() == r.is_neg(), f"for element {i}, was {meta_r.is_neg()} but real is_neg was {r.is_neg()}")
 
 
 # This environment variable controls whether or not we print expected failure
@@ -3073,6 +3073,7 @@ def register_meta_foreach(ops):
         aten._foreach_log1p,
         aten._foreach_log2,
         aten._foreach_neg,
+        aten._foreach_norm,
         aten._foreach_reciprocal,
         aten._foreach_round,
         aten._foreach_sigmoid,
@@ -9693,11 +9693,7 @@ foreach_reduce_op_db: List[ForeachFuncInfo] = [
             DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
             DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
             DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
             DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
         ),
     ),
 ]