Add meta registration for _foreach_norm (2nd try) (#119927)
The first try reused TensorListMetadata, which caused illegal memory accesses when the tensor list held too many tensors. This version instead launches multiple kernels over a simpler struct, TensorListAddresses, that carries only the output addresses; each launch can carry up to 400 pointers, so only a handful of launches are needed even for long lists.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119927
Approved by: https://github.com/albanD
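To make the diff below easier to follow, here is a rough, standalone Python sketch (illustrative only, not PyTorch source; launch_plan and cap are made-up names) of the batching arithmetic the CUDA path now uses: outputs are split into groups of at most MAX_TENSORS_PER_KERNEL (400), one cleanup-kernel launch per group, with the last launch taking the remainder.

    # Illustrative sketch of the launch batching; mirrors the ceil_div plus
    # remainder handling in the diff, but is not the actual implementation.
    def launch_plan(ntensors: int, cap: int = 400) -> list[int]:
        num_kernels = -(-ntensors // cap)  # ceil_div
        plan = []
        for i in range(num_kernels):
            is_full = i < num_kernels - 1 or ntensors % cap == 0
            plan.append(cap if is_full else ntensors % cap)
        return plan

    assert launch_plan(600) == [400, 200]  # the N=600 test below needs two launches
    assert launch_plan(400) == [400]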
Commit 4319735ace (parent 707cde9b31), committed by PyTorch MergeBot.
@@ -2,6 +2,7 @@
 #include <ATen/AccumulateType.h>
 #include <ATen/Dispatch.h>
 #include <ATen/OpMathType.h>
+#include <ATen/ceil_div.h>
 #include <ATen/native/ForeachUtils.h>
 #include <ATen/cuda/DeviceUtils.cuh>
 #include <ATen/native/cuda/ForeachFunctors.cuh>
@@ -23,6 +24,23 @@ namespace at::native {
 // _foreach_norm supports only L1, L2, and inf norm
 enum class NormType { L1, L2, LInf };
 
+// NOTE: This is a simple variant of TensorListMetadata in MultiTensorApply.cuh
+// as we only need to track addresses for the lpnorm_cleanup function below.
+// Why is this struct necessary? For the same reason the TensorListMetadata
+// struct is necessary--which is to ferry static metadata to the CUDA kernel
+// while complying with the 4kb size constraint. Since we only need to track
+// addresses, we introduce this struct to be able to fit more Tensor pointers at
+// a time, currently 400 empirically, compared to the much smaller values in
+// depth_to_max_tensors. This way, we can launch fewer kernels for better
+// performance.
+//
+// IF YOU USE THIS STRUCT, PLEASE ADD A ONE-OFF TEST IN test_foreach.py AS THIS
+// IS CURRENTLY ONLY TESTED FOR _foreach_norm.
+const size_t MAX_TENSORS_PER_KERNEL = 400;
+struct TensorListAddresses {
+  const void* addresses[MAX_TENSORS_PER_KERNEL];
+};
+
 template <
     typename T,
     NormType norm_type,
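As a sanity check on the comment above (illustrative arithmetic only, assuming 8-byte device pointers): 400 addresses stay well under the roughly 4 KiB CUDA kernel-argument budget the struct is designed around.

    # 400 pointers * 8 bytes = 3200 bytes, below the ~4 KiB kernel-argument limit
    POINTER_BYTES = 8
    MAX_TENSORS_PER_KERNEL = 400
    assert MAX_TENSORS_PER_KERNEL * POINTER_BYTES <= 4096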
@@ -112,7 +130,7 @@ template <
     typename opmath_t = at::opmath_type<T>>
 __global__ void lpnorm_cleanup(
     const opmath_t* output_per_tensor,
-    T* ret_per_tensor,
+    TensorListAddresses addr_struct,
     int max_chunks_per_tensor) {
   __shared__ opmath_t vals[512];
 
@@ -130,7 +148,7 @@ __global__ void lpnorm_cleanup(
       ? at::native::cuda_utils::BlockReduceSum<opmath_t>(val, vals)
       : at::native::cuda_utils::BlockReduceMax(val, vals);
   if (threadIdx.x == 0) {
-    ret_per_tensor[blockIdx.x] =
+    *(T*)addr_struct.addresses[blockIdx.x] =
         norm_type == NormType::L1 || norm_type == NormType::LInf
         ? final_val
         : ::sqrt(final_val);
@@ -166,7 +184,7 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
     return foreach_tensor_norm_slow(tensors, ord);
   }
 
-  const int ntensors = tensors.size();
+  const size_t ntensors = tensors.size();
   int max_chunks_per_tensor = -1;
 
   for (int t = 0; t < ntensors; t++) {
@@ -178,9 +196,14 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
   }
   const auto options = tensors[0].options();
   auto output_per_tensor = at::zeros(
-      {ntensors * max_chunks_per_tensor},
+      {static_cast<int64_t>(ntensors) * max_chunks_per_tensor},
       options.dtype(toOpMathType(tensors[0].scalar_type())));
-  auto ret_per_tensor = at::empty({ntensors}, options);
+
+  std::vector<at::Tensor> vec_res;
+  vec_res.reserve(ntensors);
+  for (int i = 0; i < ntensors; i++) {
+    vec_res.push_back(at::empty({}, options));
+  }
 
   auto tensor_lists = std::vector<std::vector<Tensor>>{tensors.vec()};
   if (p == static_cast<double>(1)) {
@@ -200,11 +223,28 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
           const at::cuda::OptionalCUDAGuard device_guard(
               device_of(output_per_tensor));
           auto stream = at::cuda::getCurrentCUDAStream();
-          lpnorm_cleanup<scalar_t, NormType::L1><<<ntensors, 512, 0, stream>>>(
-              output_per_tensor.const_data_ptr<opmath_t>(),
-              ret_per_tensor.mutable_data_ptr<scalar_t>(),
+
+          const size_t num_kernels = ceil_div(ntensors, MAX_TENSORS_PER_KERNEL);
+          for (auto i = 0; i < num_kernels; i++) {
+            const size_t num_tensors_this_kernel =
+                (i < num_kernels - 1 || ntensors % MAX_TENSORS_PER_KERNEL == 0)
+                ? MAX_TENSORS_PER_KERNEL
+                : (ntensors % MAX_TENSORS_PER_KERNEL);
+
+            TensorListAddresses addr_struct;
+            for (auto j = 0; j < num_tensors_this_kernel; j++) {
+              addr_struct.addresses[j] = vec_res[i * MAX_TENSORS_PER_KERNEL + j]
+                                             .mutable_data_ptr<scalar_t>();
+            }
+
+            lpnorm_cleanup<scalar_t, NormType::L1>
+                <<<num_tensors_this_kernel, 512, 0, stream>>>(
+                    output_per_tensor.const_data_ptr<opmath_t>() +
+                        i * MAX_TENSORS_PER_KERNEL * max_chunks_per_tensor,
+                    addr_struct,
                     max_chunks_per_tensor);
             C10_CUDA_KERNEL_LAUNCH_CHECK();
+          }
         });
   } else if (p == static_cast<double>(2)) {
     AT_DISPATCH_FLOATING_TYPES_AND2(
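The pointer offset in the launch above can be read as plain row indexing (the sketch below is illustrative, not PyTorch source; chunk_offset is a made-up helper): output_per_tensor is a flat [ntensors, max_chunks_per_tensor] buffer of partial reductions, so launch i starts reading at row i * MAX_TENSORS_PER_KERNEL.

    # Flat offset of the first partial-result row handled by launch i.
    def chunk_offset(i: int, cap: int, max_chunks_per_tensor: int) -> int:
        return i * cap * max_chunks_per_tensor

    assert chunk_offset(1, 400, 7) == 2800  # second launch skips 400 rows of 7 chunks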
@@ -223,11 +263,28 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
          const at::cuda::OptionalCUDAGuard device_guard(
               device_of(output_per_tensor));
           auto stream = at::cuda::getCurrentCUDAStream();
-          lpnorm_cleanup<scalar_t, NormType::L2><<<ntensors, 512, 0, stream>>>(
-              output_per_tensor.const_data_ptr<opmath_t>(),
-              ret_per_tensor.mutable_data_ptr<scalar_t>(),
+
+          const size_t num_kernels = ceil_div(ntensors, MAX_TENSORS_PER_KERNEL);
+          for (auto i = 0; i < num_kernels; i++) {
+            const size_t num_tensors_this_kernel =
+                (i < num_kernels - 1 || ntensors % MAX_TENSORS_PER_KERNEL == 0)
+                ? MAX_TENSORS_PER_KERNEL
+                : (ntensors % MAX_TENSORS_PER_KERNEL);
+
+            TensorListAddresses addr_struct;
+            for (auto j = 0; j < num_tensors_this_kernel; j++) {
+              addr_struct.addresses[j] = vec_res[i * MAX_TENSORS_PER_KERNEL + j]
+                                             .mutable_data_ptr<scalar_t>();
+            }
+
+            lpnorm_cleanup<scalar_t, NormType::L2>
+                <<<num_tensors_this_kernel, 512, 0, stream>>>(
+                    output_per_tensor.const_data_ptr<opmath_t>() +
+                        i * MAX_TENSORS_PER_KERNEL * max_chunks_per_tensor,
+                    addr_struct,
                     max_chunks_per_tensor);
             C10_CUDA_KERNEL_LAUNCH_CHECK();
+          }
         });
   } else if (p == std::numeric_limits<double>::infinity()) {
     AT_DISPATCH_FLOATING_TYPES_AND2(
@@ -246,12 +303,28 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
           const at::cuda::OptionalCUDAGuard device_guard(
               device_of(output_per_tensor));
           auto stream = at::cuda::getCurrentCUDAStream();
+
+          const size_t num_kernels = ceil_div(ntensors, MAX_TENSORS_PER_KERNEL);
+          for (auto i = 0; i < num_kernels; i++) {
+            const size_t num_tensors_this_kernel =
+                (i < num_kernels - 1 || ntensors % MAX_TENSORS_PER_KERNEL == 0)
+                ? MAX_TENSORS_PER_KERNEL
+                : (ntensors % MAX_TENSORS_PER_KERNEL);
+
+            TensorListAddresses addr_struct;
+            for (auto j = 0; j < num_tensors_this_kernel; j++) {
+              addr_struct.addresses[j] = vec_res[i * MAX_TENSORS_PER_KERNEL + j]
+                                             .mutable_data_ptr<scalar_t>();
+            }
+
             lpnorm_cleanup<scalar_t, NormType::LInf>
-              <<<ntensors, 512, 0, stream>>>(
-                  output_per_tensor.const_data_ptr<opmath_t>(),
-                  ret_per_tensor.mutable_data_ptr<scalar_t>(),
+                <<<num_tensors_this_kernel, 512, 0, stream>>>(
+                    output_per_tensor.const_data_ptr<opmath_t>() +
+                        i * MAX_TENSORS_PER_KERNEL * max_chunks_per_tensor,
+                    addr_struct,
                     max_chunks_per_tensor);
             C10_CUDA_KERNEL_LAUNCH_CHECK();
+          }
         });
   } else {
     TORCH_CHECK(
@@ -267,7 +340,7 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
   int i = 0;
   for (const auto& t : tensors) {
     if (t.numel() != 0) {
-      result.emplace_back(ret_per_tensor[i]);
+      result.emplace_back(vec_res[i]);
       i++;
     } else {
       result.emplace_back(at::zeros({}, options));
@@ -646,6 +646,20 @@ class TestForeach(TestCase):
                     for i, e in enumerate(expect)))
             self.assertEqual(expect, actual, equal_nan=False)
 
+    @onlyCUDA
+    @ops(foreach_reduce_op_db, allowed_dtypes=floating_types())
+    def test_big_num_tensors(self, device, dtype, op):
+        N = 600
+        tensorlist = [make_tensor((2, 3), dtype=dtype, device=device, noncontiguous=False) for _ in range(N)]
+        fn, ref_fn, *_ = self._get_funcs(op)
+
+        import math
+        for ord in (1, 2, math.inf):
+            actual = fn(inputs=[tensorlist], is_cuda=True, expect_fastpath=True, ord=ord, zero_size=False)
+            expect = ref_fn(inputs=[tensorlist], ord=ord)
+
+            self.assertEqual(expect, actual, equal_nan=True)
+
     @onlyCUDA
     @ops(foreach_reduce_op_db)
     def test_foreach_reduce_large_input(self, device, dtype, op):
@@ -401,22 +401,24 @@ def assert_ref_meta_equal(test_case, func, meta_rs, rs, msg_callable):
         if not isinstance(r, torch.Tensor):
             continue
         test_assert(isinstance(meta_r, torch.Tensor), f"but real {i}th result is Tensor")
-        test_assert(meta_r.dtype == r.dtype, f"but real dtype was {r.dtype}")
-        test_assert(meta_r.shape == r.shape, f"but real shape was {r.shape}")
+        test_assert(meta_r.dtype == r.dtype, f"for element {i}, was {meta_r.dtype} but real dtype was {r.dtype}")
+        test_assert(meta_r.shape == r.shape, f"for element {i}, was {meta_r.shape} but real shape was {r.shape}")
         # See https://github.com/pytorch/pytorch/issues/78050
         if should_check_strides(func) == CheckStrides.ALL:
             same_strides, _ = torch._prims_common.check_all_strides(meta_r, r)
-            test_assert(same_strides, f"but real stride was {r.stride()}")
+            test_assert(same_strides, f"for element {i}, was {meta_r.stride()} but real stride was {r.stride()}")
         elif should_check_strides(func) == CheckStrides.SIGNIFICANT:
             same_strides, _ = torch._prims_common.check_significant_strides(meta_r, r)
-            test_assert(same_strides, f"but real stride was {r.stride()}")
+            test_assert(same_strides, f"for element {i}, was {meta_r.stride()} but real stride was {r.stride()}")
         test_assert(
             meta_r.storage_offset() == r.storage_offset(),
-            f"but real storage_offset was {r.storage_offset()}")
-        test_assert(meta_r.requires_grad == r.requires_grad, f"but real requires_grad was {r.requires_grad}")
+            f"for element {i}, was {meta_r.storage_offset()} but real storage_offset was {r.storage_offset()}")
+        test_assert(meta_r.requires_grad == r.requires_grad,
+                    f"for element {i}, was {meta_r.requires_grad} but real requires_grad was {r.requires_grad}")
         if func not in CHECK_CONJ_SKIPS:
-            test_assert(meta_r.is_conj() == r.is_conj(), f"but real is_conj was {r.is_conj()}")
-        test_assert(meta_r.is_neg() == r.is_neg(), f"but real is_neg was {r.is_neg()}")
+            test_assert(meta_r.is_conj() == r.is_conj(),
+                        f"for element {i}, was {meta_r.is_conj()} but real is_conj was {r.is_conj()}")
+        test_assert(meta_r.is_neg() == r.is_neg(), f"for element {i}, was {meta_r.is_neg()} but real is_neg was {r.is_neg()}")
 
 
 # This environment variable controls whether or not we print expected failure
@@ -3073,6 +3073,7 @@ def register_meta_foreach(ops):
         aten._foreach_log1p,
         aten._foreach_log2,
         aten._foreach_neg,
+        aten._foreach_norm,
         aten._foreach_reciprocal,
         aten._foreach_round,
         aten._foreach_sigmoid,
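For context on what the registration above enables (an illustrative usage sketch, assuming a PyTorch build that includes this change): _foreach_norm can now run under the meta device, so tracing and fake-tensor paths get one zero-dimensional result per input without allocating real memory.

    import torch

    # Meta tensors carry only shape/dtype/device; no storage is allocated.
    ts = [torch.empty(2, 3, device="meta") for _ in range(3)]
    outs = torch._foreach_norm(ts)
    assert all(o.shape == () and o.device.type == "meta" for o in outs)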
@@ -9693,11 +9693,7 @@ foreach_reduce_op_db: List[ForeachFuncInfo] = [
             DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
             DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
             DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
             DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
         ),
     ),
 ]