[9/N] Fix extra warnings brought by clang-tidy-17 (#139286)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139286
Approved by: https://github.com/ezyang
This commit is contained in:
cyy
2024-10-31 05:20:28 +00:00
committed by PyTorch MergeBot
parent 42b5e191ae
commit f95c71867e
28 changed files with 79 additions and 63 deletions

View File

@@ -59,8 +59,11 @@ struct strided_tensor_iter_fixed {
  T* data_ = NULL;
  int64_t dim_ = 0;
+ // NOLINTNEXTLINE(*array*)
  int64_t counter_[N] = {0};
+ // NOLINTNEXTLINE(*array*)
  int64_t sizes_[N] = {0};
+ // NOLINTNEXTLINE(*array*)
  int64_t strides_[N] = {0};
  strided_tensor_iter_fixed(strided_tensor_iter_fixed const&) = delete;
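
The `*array*` pattern in these NOLINT comments is a wildcard that matches the C-array checks (such as modernize-avoid-c-arrays and cppcoreguidelines-avoid-c-arrays) in one comment. Below is a minimal standalone sketch, with a hypothetical Iter struct rather than the PyTorch type, contrasting the suppression with the std::array alternative the check would otherwise suggest.

#include <array>
#include <cstdint>

// Hypothetical stand-in illustrating two ways to satisfy the C-array checks.
template <int N>
struct Iter {
  // Option A: keep the C array and suppress every *array* check.
  // NOLINTNEXTLINE(*array*)
  int64_t counter_[N] = {0};

  // Option B: switch to std::array, which the checks accept outright.
  std::array<int64_t, N> sizes_{};
};

int main() {
  Iter<4> it;
  it.counter_[0] = 1;
  it.sizes_[0] = 2;
  return 0;
}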

View File

@@ -11,15 +11,15 @@
  namespace at {
- static cpu_fixed_malloc(void*, ptrdiff_t) {
+ static void* cpu_fixed_malloc(void*, ptrdiff_t) {
  TORCH_CHECK(false, "attempting to resize a tensor view of an external blob");
  }
- static cpu_fixed_realloc(void*, void*, ptrdiff_t) {
+ static void* cpu_fixed_realloc(void*, void*, ptrdiff_t) {
  TORCH_CHECK(false, "attempting to resize a tensor view of an external blob");
  }
- static cpu_fixed_free(void* state, void* allocation) {
+ static void cpu_fixed_free(void* state, void* allocation) {
  auto on_release = static_cast<std::function<void(void*)>*>(state);
  (*on_release)(allocation);
  delete on_release;
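
The replacement lines spell out the void*/void return types that the old declarations omit as displayed. The free callback also shows a common ownership pattern: the opaque state pointer holds a heap-allocated std::function that is invoked once and then deleted. A standalone sketch of that pattern follows; the names are illustrative, not the ATen API.

#include <cstdio>
#include <functional>

// Illustrative release callback: 'state' owns a heap-allocated
// std::function<void(void*)> that is called once and then freed.
static void fixed_free(void* state, void* allocation) {
  auto* on_release = static_cast<std::function<void(void*)>*>(state);
  (*on_release)(allocation);
  delete on_release;
}

int main() {
  int buffer = 42;
  // The callback's state is created up front and handed over as void*.
  auto* hook = new std::function<void(void*)>(
      [](void* p) { std::printf("releasing %p\n", p); });
  fixed_free(hook, &buffer);
  return 0;
}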

View File

@@ -256,7 +256,7 @@ Tensor FunctionalInverses::split_with_sizes_inverse(const Tensor& base, const Te
  dim = at::maybe_wrap_dim(dim, base.dim());
  auto dim_size = base.sym_size(dim);
  c10::SymInt start = 0;
- for (auto i = 0; i < mutated_view_idx; ++i) {
+ for (int64_t i = 0; i < mutated_view_idx; ++i) {
  start += split_sizes[i];
  }
  auto end = start + split_sizes[mutated_view_idx];

View File

@@ -83,10 +83,10 @@ static c10::SymInt get_nbytes(const Tensor& value) {
  if (value.key_set().has(c10::DispatchKey::Python)) {
  return value.storage().sym_nbytes();
  }
- return at::detail::computeStorageNbytes(value.sym_sizes(), value.sym_strides(), value.dtype().itemsize(), value.sym_storage_offset());
+ return at::detail::computeStorageNbytes(value.sym_sizes(), value.sym_strides(), static_cast<int64_t>(value.dtype().itemsize()), value.sym_storage_offset());
  }
  // XLA storage objects also do not properly track nbytes.
- return at::detail::computeStorageNbytes(value.sizes(), value.strides(), value.dtype().itemsize(), value.storage_offset());
+ return static_cast<int64_t>(at::detail::computeStorageNbytes(value.sizes(), value.strides(), value.dtype().itemsize(), value.storage_offset()));
  }
  FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& base)
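
The casts here make explicit the conversions between unsigned sizes (such as dtype().itemsize()) and the signed 64-bit quantities the surrounding code works with, which is what the narrowing-conversion checks flag. A minimal sketch of the general rule in plain C++, with a hypothetical compute_nbytes standing in for the ATen helper:

#include <cstdint>
#include <cstdio>

// A signed API taking an element size, analogous in shape to computeStorageNbytes.
int64_t compute_nbytes(int64_t numel, int64_t itemsize) {
  return numel * itemsize;
}

int main() {
  size_t itemsize = sizeof(double);  // unsigned, like dtype().itemsize()
  // The implicit size_t -> int64_t conversion is what the check flags;
  // an explicit cast documents that the value is expected to fit.
  int64_t nbytes = compute_nbytes(16, static_cast<int64_t>(itemsize));
  std::printf("%lld\n", static_cast<long long>(nbytes));
  return 0;
}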

View File

@@ -154,7 +154,7 @@ static void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, t
  "please file a bug report instead.");
  }
  batched_tensor_inputs.push_back(tensor);
- batched_tensor_inputs_position.push_back(idx);
+ batched_tensor_inputs_position.push_back(static_cast<int64_t>(idx));
  }
  TORCH_INTERNAL_ASSERT(!batched_tensor_inputs.empty());
@@ -288,7 +288,7 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Sta
  continue;
  }
  batched_tensor_inputs.push_back(tensor);
- batched_tensor_inputs_position.push_back(idx);
+ batched_tensor_inputs_position.push_back(static_cast<int64_t>(idx));
  }
  TORCH_INTERNAL_ASSERT(!batched_tensor_inputs.empty());

View File

@@ -25,7 +25,7 @@ BatchedTensorImpl::BatchedTensorImpl(Tensor value, BatchDims bdims)
  const auto value_strides = value_.strides();
  sizes_and_strides_.resize(public_dims);
  for (const auto dim : c10::irange(public_dims)) {
- auto actual_dim = actualDim(dim, /*wrap_dim=*/false);
+ auto actual_dim = actualDim(static_cast<int64_t>(dim), /*wrap_dim=*/false);
  sizes_and_strides_.size_at_unchecked(dim) = value_sizes.at(actual_dim);
  sizes_and_strides_.stride_at_unchecked(dim) = value_strides.at(actual_dim);
  }
@@ -37,7 +37,7 @@ BatchedTensorImpl::BatchedTensorImpl(Tensor value, BatchDims bdims)
  int64_t BatchedTensorImpl::actualDim(int64_t dim, bool wrap_dim) const {
  if (wrap_dim) {
  const auto ndim = sizes_and_strides_.size();
- dim = maybe_wrap_dim(dim, ndim);
+ dim = maybe_wrap_dim(dim, static_cast<int64_t>(ndim));
  }
  auto is_bdim = createBatchDimBitset(bdims_);
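
sizes_and_strides_.size() is unsigned, while dimension arithmetic here is done in signed int64_t, hence the explicit cast. For context, dimension wrapping maps a possibly negative dim into [0, ndim); a self-contained sketch of that idea (not the c10 implementation) follows.

#include <cassert>
#include <cstdint>

// Sketch of dimension wrapping: -1 refers to the last dimension, and so on.
int64_t wrap_dim(int64_t dim, int64_t ndim) {
  assert(ndim > 0 && dim >= -ndim && dim < ndim);
  return dim < 0 ? dim + ndim : dim;
}

int main() {
  size_t ndim = 4;  // container sizes are unsigned, as in the diff
  // The explicit cast mirrors the fix: keep the arithmetic signed.
  assert(wrap_dim(-1, static_cast<int64_t>(ndim)) == 3);
  assert(wrap_dim(2, static_cast<int64_t>(ndim)) == 2);
  return 0;
}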

View File

@@ -366,7 +366,7 @@ Tensor select_batching_rule(const Tensor& self, int64_t dim, int64_t index) {
  }
  static int64_t getGradInputPhysicalDim(int64_t dim, IntArrayRef input_sizes, int64_t num_batch_dims) {
- return maybe_wrap_dim(dim, input_sizes.size()) + num_batch_dims;
+ return maybe_wrap_dim(dim, static_cast<int64_t>(input_sizes.size())) + num_batch_dims;
  }
  Tensor select_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) {

View File

@@ -35,7 +35,7 @@ static Tensor permuteBatchDimsToFront(BatchedTensorImpl* batched) {
  if (is_bdim[ptr]) {
  continue;
  }
- permutation[idx++] = ptr;
+ permutation[idx++] = static_cast<int64_t>(ptr);
  }
  return physical_tensor.permute(permutation);
  }
@@ -49,7 +49,7 @@ VmapPhysicalView MultiBatchVmapTransform::logicalToPhysical(const Tensor& logica
  }
  int64_t VmapPhysicalView::numBatchDims() const {
- return levels_.count();
+ return static_cast<int64_t>(levels_.count());
  }
  int64_t VmapPhysicalView::numLogicalDims() const {
@@ -202,7 +202,7 @@ MultiBatchVmapTransform::logicalToPhysical(ITensorListRef logical_tensors) {
  // batch dims have been moved to the front of the tensor. Any previously
  // non-existing batch dims get added to the tensors as new dimensions of size 1.
  std::vector<Tensor> physical_tensors;
- int64_t num_batch_dims = collective_levels.count();
+ auto num_batch_dims = collective_levels.count();
  for (const auto& logical_tensor : logical_tensors) {
  auto requested_example_dim = /*logical_dim*/logical_tensor.dim();
  auto physical_tensor = alignBatchDimsAtFront(

View File

@@ -21,7 +21,7 @@ ThreadLocalState::ThreadLocalState()
  saved_tensors_default_hooks_state_(at::SavedTensorDefaultHooks::get_tls_state()), functionalization_reapply_views_state_(at::functionalization::impl::getFunctionalizationReapplyViewsTLS()),
  saved_objects_(at::impl::ThreadLocalPythonObjects::get_state()) {
  #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
- for(uint8_t i=0; i<autocast_dtypes_.size(); i++) {
+ for(size_t i=0; i<autocast_dtypes_.size(); i++) {
  autocast_dtypes_[i] = at::autocast::get_autocast_dtype(static_cast<at::DeviceType>(i));
  }
  #endif
@@ -62,7 +62,7 @@ void ThreadLocalState::setThreadLocalState(
  at::impl::ThreadLocalPythonObjects::set_state(state.saved_objects_);
  #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER)
- for(uint8_t i=0; i<state.autocast_dtypes_.size(); i++) {
+ for(size_t i=0; i<state.autocast_dtypes_.size(); i++) {
  at::autocast::set_autocast_dtype(static_cast<at::DeviceType>(i), state.autocast_dtypes_[i]);
  }
  #endif
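
Widening the loop index from uint8_t to size_t matches the type of .size() and avoids the warning clang-tidy raises for loop variables narrower than their bound (likely bugprone-too-small-loop-variable or a sign/width comparison check); the per-step cast to at::DeviceType stays explicit. A minimal sketch, using a hypothetical DeviceType enum:

#include <array>
#include <cstdint>
#include <cstdio>

enum class DeviceType : int8_t { CPU = 0, CUDA = 1, XPU = 2 };

int main() {
  std::array<int, 3> dtypes_per_device{};
  // size_t index: same width as .size(), so no narrow-loop-variable warning.
  for (size_t i = 0; i < dtypes_per_device.size(); i++) {
    dtypes_per_device[i] = static_cast<int>(static_cast<DeviceType>(i));
  }
  std::printf("%d\n", dtypes_per_device[2]);
  return 0;
}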

View File

@@ -67,14 +67,14 @@ class Operation {
  // treat the last N elements of the stack as a list, looking up
  // element i
  inline IValue& peek(Stack& stack, size_t i, size_t N) {
- // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions)
+ // NOLINTNEXTLINE(*-narrowing-conversions)
  return *(stack.end() - N + i);
  }
  inline IValue& peek(Stack* stack, size_t i, size_t N) {
  return peek(*stack, i, N);
  }
  inline const IValue& peek(const Stack& stack, size_t i, size_t N) {
- // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions)
+ // NOLINTNEXTLINE(*-narrowing-conversions)
  return *(stack.end() - N + i);
  }
  inline const IValue& peek(const Stack* stack, size_t i, size_t N) {
@@ -96,7 +96,7 @@ inline at::ArrayRef<IValue> last(const Stack* stack, size_t N) {
  return last(*stack, N);
  }
  inline void drop(Stack& stack, size_t n) {
- // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions)
+ // NOLINTNEXTLINE(*-narrowing-conversions)
  stack.erase(stack.end() - n, stack.end())
  }
  inline void drop(Stack* stack, size_t n) {
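
The old comments suppressed only cppcoreguidelines-narrowing-conversions; clang-tidy also reports the aliased bugprone-narrowing-conversions spelling, so the wildcard *-narrowing-conversions silences both with one comment. The narrowing itself comes from mixing size_t operands with the signed difference type of iterator arithmetic. A small sketch of the same pattern on a plain std::vector:

#include <cstddef>
#include <vector>

// Treat the last N stack slots as a list and fetch element i,
// mirroring the peek() helper in the diff.
inline int& peek(std::vector<int>& stack, size_t i, size_t N) {
  // NOLINTNEXTLINE(*-narrowing-conversions)
  return *(stack.end() - N + i);
}

int main() {
  std::vector<int> stack{10, 20, 30, 40};
  // Last 2 elements are {30, 40}; element 0 of that view is 30.
  return peek(stack, 0, 2) == 30 ? 0 : 1;
}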

View File

@@ -282,6 +282,7 @@ class CuBlasLtMatmulDescriptor : public CuBlasLtDescriptor<
  }
  template <typename T>
  inline void setAttribute(cublasLtMatmulDescAttributes_t attr, const T value) {
+ // NOLINTNEXTLINE(bugprone-sizeof-expression)
  TORCH_CUDABLAS_CHECK(::cublasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T)));
  }
  };
@@ -1750,6 +1751,7 @@ void trsm<c10::complex<double>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<double>)) {
  }
  template <>
+ // NOLINTNEXTLINE(*array*)
  void trsmBatched<float>(CUDABLAS_TRSM_BATCHED_ARGTYPES(float)) {
  TORCH_CUDABLAS_CHECK(cublasStrsmBatched(
  handle,
@@ -1768,6 +1770,7 @@ void trsmBatched<float>(CUDABLAS_TRSM_BATCHED_ARGTYPES(float)) {
  }
  template <>
+ // NOLINTNEXTLINE(*array*)
  void trsmBatched<double>(CUDABLAS_TRSM_BATCHED_ARGTYPES(double)) {
  TORCH_CUDABLAS_CHECK(cublasDtrsmBatched(
  handle,
@@ -1787,6 +1790,7 @@ void trsmBatched<double>(CUDABLAS_TRSM_BATCHED_ARGTYPES(double)) {
  template <>
  void trsmBatched<c10::complex<float>>(
+ // NOLINTNEXTLINE(*array*)
  CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<float>)) {
  TORCH_CUDABLAS_CHECK(cublasCtrsmBatched(
  handle,
@@ -1806,6 +1810,7 @@ void trsmBatched<c10::complex<float>>(
  template <>
  void trsmBatched<c10::complex<double>>(
+ // NOLINTNEXTLINE(*array*)
  CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<double>)) {
  TORCH_CUDABLAS_CHECK(cublasZtrsmBatched(
  handle,
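
bugprone-sizeof-expression treats sizeof(T) in generic code with suspicion, presumably because T can be deduced as a pointer type here and sizeof of a pointer is a common bug pattern; in this setter the byte count of the attribute payload really is sizeof(T), so the warning is suppressed rather than "fixed". A hedged sketch with a hypothetical setter (not the cuBLAS API):

#include <cstdio>
#include <cstring>

// Hypothetical attribute setter: copies 'size' bytes of 'value', in the
// same spirit as an attribute-setting C API that takes (ptr, size).
void set_attribute(void* dst, const void* value, size_t size) {
  std::memcpy(dst, value, size);
}

template <typename T>
void set_typed_attribute(void* dst, const T value) {
  // sizeof(T) is intended even when T is deduced as a pointer type,
  // which is exactly what bugprone-sizeof-expression warns about.
  // NOLINTNEXTLINE(bugprone-sizeof-expression)
  set_attribute(dst, &value, sizeof(T));
}

int main() {
  int32_t storage = 0;
  set_typed_attribute(&storage, int32_t{7});
  std::printf("%d\n", storage);
  return 0;
}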

View File

@@ -33,7 +33,7 @@ void init_p2p_access_cache(int64_t num_devices) {
  } // namespace detail
- bool get_p2p_access(int dev, int dev_to_access) {
+ bool get_p2p_access(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) {
  at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
  TORCH_CHECK(dev >= 0 || dev < num_devices_,

View File

@@ -1,4 +1,5 @@
  #include <c10/macros/Macros.h>
+ #include <c10/core/Device.h>
  #include <cstdint>
  namespace at::cuda {
@@ -6,6 +7,6 @@ namespace detail {
  void init_p2p_access_cache(int64_t num_devices);
  }
- TORCH_CUDA_CPP_API bool get_p2p_access(int source_dev, int dest_dev);
+ TORCH_CUDA_CPP_API bool get_p2p_access(c10::DeviceIndex source_dev, c10::DeviceIndex dest_dev);
  } // namespace at::cuda

View File

@@ -59,8 +59,8 @@ void remove_padding_kernelLauncher(
  const int* offsets,
  const int* input_sizes,
  const int* output_sizes,
- int output_dim,
- const int batch_size);
+ int64_t output_dim,
+ const int64_t batch_size);
  template <typename T>
  void remove_padding_transform0213_kernelLauncher(
@@ -69,8 +69,8 @@ void remove_padding_transform0213_kernelLauncher(
  const int* offsets,
  const int* input_sizes,
  const int* output_sizes,
- int output_dim,
- const int batch_size);
+ int64_t output_dim,
+ const int64_t batch_size);
  template <typename T>
  void add_padding_kernelLauncher(

View File

@@ -1,5 +1,3 @@
- #include <numeric>
- #include <algorithm>
  #include <c10/util/Exception.h>
  #include <ATen/ATen.h>
@@ -118,7 +116,7 @@ Tensor nested_from_padded_cuda(
  }
  }
- Tensor batch_offsets_from_efficient_size(const Tensor& ef_sizes) {
+ static Tensor batch_offsets_from_efficient_size(const Tensor& ef_sizes) {
  int64_t* nt_sizes_ptr = ef_sizes.data_ptr<int64_t>();
  int64_t ef_sizes_size_0 = ef_sizes.sizes()[0];
  Tensor offsets = at::empty({1 + ef_sizes_size_0}, at::kLong);

View File

@@ -154,8 +154,8 @@ void remove_padding_kernelLauncher(
  const int* offsets,
  const int* input_sizes,
  const int* output_sizes,
- int output_dim,
- const int batch_size) {
+ int64_t output_dim,
+ const int64_t batch_size) {
  dim3 grid;
  grid.x = batch_size;
  grid.y = GRID_DIM_Y;
@@ -188,8 +188,8 @@ void remove_padding_transform0213_kernelLauncher(
  const int* offsets,
  const int* input_sizes,
  const int* output_sizes,
- int output_dim,
- const int batch_size) {
+ int64_t output_dim,
+ const int64_t batch_size) {
  dim3 grid;
  grid.x = batch_size;
  grid.y = GRID_DIM_Y;
@@ -214,8 +214,8 @@ template void remove_padding_kernelLauncher<float>(
  const int* offsets,
  const int* input_sizes,
  const int* output_sizes,
- int output_dim,
- const int batch_size);
+ int64_t output_dim,
+ const int64_t batch_size);
  template void remove_padding_kernelLauncher<c10::Half>(
  const c10::Half* input,
@@ -223,8 +223,8 @@ template void remove_padding_kernelLauncher<c10::Half>(
  const int* offsets,
  const int* input_sizes,
  const int* output_sizes,
- int output_dim,
- const int batch_size);
+ int64_t output_dim,
+ const int64_t batch_size);
  template void remove_padding_transform0213_kernelLauncher<float>(
  const float* input,
@@ -232,8 +232,8 @@ template void remove_padding_transform0213_kernelLauncher<float>(
  const int* offsets,
  const int* input_sizes,
  const int* output_sizes,
- int output_dim,
- const int batch_size);
+ int64_t output_dim,
+ const int64_t batch_size);
  template void remove_padding_transform0213_kernelLauncher<c10::Half>(
  const c10::Half* input,
@@ -241,8 +241,8 @@ template void remove_padding_transform0213_kernelLauncher<c10::Half>(
  const int* offsets,
  const int* input_sizes,
  const int* output_sizes,
- int output_dim,
- const int batch_size);
+ int64_t output_dim,
+ const int64_t batch_size);
  template <typename T>
  __global__ void add_padding_1(

View File

@@ -89,7 +89,7 @@ int64_t get_nnz(const Tensor& nestedtensor) {
  const Tensor& tensor_strides = tensor->get_nested_strides();
  const int64_t n_tensors = tensor_strides.size(0);
- constexpr int64_t n_dims = 3;
+ constexpr int n_dims = 3;
  // This is safe since head_dim is assured to be consistent
  const int64_t num_heads = tensor -> opt_size(2).value();
  const int64_t tensor_stride_0 = tensor_strides.stride(0);

View File

@@ -114,7 +114,7 @@ static PyObject* THPEvent_record(
  auto stream = (THPStream*)_stream;
  self->event.record(c10::Stream::unpack3(
  stream->stream_id,
- stream->device_index,
+ static_cast<c10::DeviceIndex>(stream->device_index),
  static_cast<c10::DeviceType>(stream->device_type)));
  } else {
  c10::impl::VirtualGuardImpl impl{
@@ -192,7 +192,7 @@ static PyObject* THPEvent_wait(
  auto stream = (THPStream*)_stream;
  self->event.block(c10::Stream::unpack3(
  stream->stream_id,
- stream->device_index,
+ static_cast<c10::DeviceIndex>(stream->device_index),
  static_cast<c10::DeviceType>(stream->device_type)));
  } else {
  c10::impl::VirtualGuardImpl impl{

View File

@@ -326,7 +326,7 @@ static PyObject* THPModule_setNumThreads(PyObject* module, PyObject* arg) {
  static PyObject* THPModule_getNumInteropThreads(
  PyObject* module,
  PyObject* noargs) {
- return THPUtils_packInt32(at::get_num_interop_threads());
+ return THPUtils_packUInt64(at::get_num_interop_threads());
  }
  static PyObject* THPModule_setNumInteropThreads(

View File

@@ -46,7 +46,7 @@ using PyModuleClass =
  /// to which it delegates all calls.
  template <typename ModuleType>
  void bind_cpp_module_wrapper(
- py::module module,
+ const py::module& module,
  PyModuleClass<ModuleType> cpp_class,
  const char* name) {
  // Grab the `torch.nn.cpp.ModuleWrapper` class, which we'll subclass
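
performance-unnecessary-value-param flags parameters that are copied but only read; taking py::module by value copies a reference-counted handle on every call, and a const reference avoids that without changing behavior. A sketch of the general fix with a plain C++ stand-in (no pybind11 dependency), where Handle and bind_wrapper are illustrative names:

#include <memory>
#include <string>

// Stand-in for a cheap-to-copy-but-not-free handle such as py::module.
struct Handle {
  std::shared_ptr<std::string> impl;
};

// Before: taking 'Handle h' by value bumps the refcount at every call.
// After: a const reference is enough because the function only reads it.
void bind_wrapper(const Handle& h, const char* name) {
  (void)h;
  (void)name;
}

int main() {
  Handle h{std::make_shared<std::string>("torch.nn.cpp")};
  bind_wrapper(h, "ModuleWrapper");
  return 0;
}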

View File

@@ -1280,6 +1280,7 @@ PyObject* THPModule_increment_version(
  }
  // autograd methods on torch._C
+ // NOLINTNEXTLINE(*array*)
  static PyMethodDef methods[] = {
  {"_set_grad_enabled",
  castPyCFunctionWithKeywords(set_grad_enabled),

View File

@@ -81,7 +81,7 @@ inline PyObject* wrap(at::QScheme qscheme) {
  }
  inline PyObject* wrap(at::TensorList tl) {
- auto r = THPObjectPtr{PyTuple_New(tl.size())};
+ auto r = THPObjectPtr{PyTuple_New(static_cast<Py_ssize_t>(tl.size()))};
  if (!r)
  throw python_error();
  for (const auto i : c10::irange(tl.size())) {
@@ -91,7 +91,7 @@ inline PyObject* wrap(at::TensorList tl) {
  }
  inline PyObject* wrap(at::IntArrayRef list) {
- auto r = THPObjectPtr{PyTuple_New(list.size())};
+ auto r = THPObjectPtr{PyTuple_New(static_cast<Py_ssize_t>(list.size()))};
  if (!r)
  throw python_error();
  for (const auto i : c10::irange(list.size())) {
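
PyTuple_New takes a signed Py_ssize_t, while the ArrayRef sizes here are unsigned size_t, so the explicit cast keeps the signed/unsigned conversion visible and silences the narrowing check. Since a runnable CPython example would need an embedded interpreter, here is a plain C++ sketch of the same boundary; make_tuple_like and ssize_like are hypothetical stand-ins.

#include <cstddef>
#include <vector>

using ssize_like = std::ptrdiff_t;  // plays the role of Py_ssize_t

// Hypothetical sink with a signed size parameter, like PyTuple_New.
std::vector<int> make_tuple_like(ssize_like n) {
  return std::vector<int>(static_cast<size_t>(n));
}

int main() {
  std::vector<int> tensors(3);
  // The cast at the call site mirrors the diff: size_t -> signed size.
  auto tuple = make_tuple_like(static_cast<ssize_like>(tensors.size()));
  return tuple.size() == 3 ? 0 : 1;
}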

View File

@@ -358,10 +358,12 @@ struct TORCH_API ViewFunc {
  /// Sets the values of any SymInts in the saved state. The input vector size
  /// must match the number of SymInts in the saved state (i.e. the size of the
  /// list returned by get_symints()).
+ /// NOLINTNEXTLINE(performance-unnecessary-value-param)
  virtual void set_symints(std::vector<c10::SymInt>) {}
  /// Sets the values of any Tensors in the saved state. The input vector size
  /// must match the number of Tensors in the saved state (i.e. the size of the
  /// list returned by get_tensors()).
+ /// NOLINTNEXTLINE(performance-unnecessary-value-param)
  virtual void set_tensors(std::vector<at::Tensor>) {}
  };

View File

@@ -1253,6 +1253,7 @@ static void registerCudaPluggableAllocator(PyObject* module) {
  m.def(
  "_set_storage_data_ptr_access_error_msg",
  [](size_t storage_impl_ptr, std::string s) {
+ // NOLINTNEXTLINE(performance-no-int-to-ptr)
  c10::StorageImpl* storage_impl = (c10::StorageImpl*)storage_impl_ptr;
  storage_impl->release_data_and_set_meta_custom_data_ptr_error_msg_(s);
  });
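
performance-no-int-to-ptr flags casts from an integer to a pointer because they bypass the type system; here the integer really is a pointer that made a round trip through Python as size_t, so the cast is intentional and only suppressed. A minimal sketch of that round trip with hypothetical names:

#include <cstdint>
#include <cstdio>

struct Storage {
  int payload = 7;
};

// The handle crosses an FFI boundary as an integer...
uintptr_t to_handle(Storage* s) {
  return reinterpret_cast<uintptr_t>(s);
}

// ...and is turned back into a pointer on the way in, which is the
// cast performance-no-int-to-ptr complains about.
Storage* from_handle(uintptr_t h) {
  // NOLINTNEXTLINE(performance-no-int-to-ptr)
  return reinterpret_cast<Storage*>(h);
}

int main() {
  Storage s;
  uintptr_t handle = to_handle(&s);
  std::printf("%d\n", from_handle(handle)->payload);
  return 0;
}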

View File

@@ -282,8 +282,7 @@ std::vector<at::Tensor>& scatter_out(
  at::cuda::OptionalCUDAStreamGuard cuda_guard;
  for (const auto i : c10::irange(chunks.size())) {
  if (i < (streams ? streams->size() : 0U) && (*streams)[i]) {
- const auto device_index =
- static_cast<int16_t>(out_tensors[i].get_device());
+ const auto device_index = out_tensors[i].get_device();
  TORCH_CHECK(
  (*streams)[i]->device_index() == device_index,
  "Expected the device associated with the stream at index ",
@@ -293,7 +292,7 @@ std::vector<at::Tensor>& scatter_out(
  ") ",
  "to match the device supplied at that index ",
  "(expected ",
- device_index,
+ static_cast<int16_t>(device_index),
  ")");
  cuda_guard.reset_stream(*(*streams)[i]);
  }

View File

@@ -109,6 +109,7 @@ ncclDataType_t to_nccl_data_type(c10::ScalarType type) {
  return ncclDataType_t::ncclInt;
  case at::kChar:
  return ncclDataType_t::ncclChar;
+ // NOLINTNEXTLINE(*-narrowing-conversions)
  case at::kByte:
  return ncclDataType_t::ncclUint8;
  case at::kBool:
@@ -260,8 +261,9 @@ void throw_nccl_error(torch::cuda::nccl::ncclResult status) {
  }
  struct NcclCommList {
+ // NOLINTNEXTLINE(*array*)
  std::unique_ptr<ncclComm_t[]> comms;
- int ndevices;
+ size_t ndevices;
  NcclCommList(const std::vector<int>& devices)
  : comms(new ncclComm_t[devices.size()]), ndevices(devices.size()) {
  NCCL_CHECK(ncclCommInitAll(
@@ -309,8 +311,8 @@ ArrayRef<ncclComm_t> get_communicators(TensorList inputs) {
  static inline void check_tensor(
  const at::Tensor& input,
  const std::optional<at::Tensor>& output,
- int input_multiplier,
- int output_multiplier,
+ size_t input_multiplier,
+ size_t output_multiplier,
  int64_t ref_numel,
  ScalarType ref_dtype) {
  auto check_one = [&](const at::Tensor& tensor) {
@@ -355,12 +357,12 @@ static inline void check_tensor(
  void check_inputs(
  TensorList inputs,
  TensorList outputs,
- int input_multiplier,
- int output_multiplier) {
+ size_t input_multiplier,
+ size_t output_multiplier) {
  // len(inputs) == len(outputs)
  size_t len = inputs.size();
- if (len <= 0) {
+ if (len == 0) {
  throw std::runtime_error("input sequence can't be empty");
  }
@@ -967,7 +969,7 @@ void all2all(
  uintptr_t recvBase = reinterpret_cast<uintptr_t>(outputTensors[0].data_ptr());
  size_t dtypeSize = inputTensors.front().element_size();
- for (const auto r : c10::irange(outputTensors.size())) {
+ for (const int r : c10::irange(outputTensors.size())) {
  sendCounts[r] = inputTensors[r].numel();
  auto sendOffset =
  reinterpret_cast<uintptr_t>(inputTensors[r].data_ptr()) - sendBase;
@@ -995,7 +997,7 @@ void all2all(
  stream.stream()));
  #else
  NCCL_CHECK(ncclGroupStart());
- for (const auto r : c10::irange(outputTensors.size())) {
+ for (const int r : c10::irange(static_cast<int>(outputTensors.size()))) {
  at::Tensor& input = inputTensors[r];
  at::Tensor& output = outputTensors[r];

View File

@@ -32,7 +32,7 @@ typedef void* ncclComm_t;
  * nccl impp. */
  #define NCCL_UNIQUE_ID_BYTES 128
  typedef struct {
- // NOLINTNEXTLINE(*array)
+ // NOLINTNEXTLINE(*array*)
  char internal[NCCL_UNIQUE_ID_BYTES];
  } ncclUniqueId;
@@ -100,14 +100,14 @@ TORCH_CUDA_CPP_API at::ArrayRef<ncclComm_t> get_communicators(
  TORCH_CUDA_CPP_API void check_inputs(
  at::TensorList inputs,
  at::TensorList outputs,
- int input_multiplier,
- int output_multiplier);
+ size_t input_multiplier,
+ size_t output_multiplier);
  TORCH_CUDA_CPP_API void check_inputs(
  at::TensorList inputs,
  const at::Tensor& output,
  int root,
- int input_multiplier,
- int output_multiplier);
+ size_t input_multiplier,
+ size_t output_multiplier);
  } // namespace detail

View File

@@ -73,6 +73,7 @@ void initCudartBindings(PyObject* module) {
  [](uintptr_t ptr, size_t size, unsigned int flags) -> cudaError_t {
  py::gil_scoped_release no_gil;
  return C10_CUDA_ERROR_HANDLED(
+ // NOLINTNEXTLINE(performance-no-int-to-ptr)
  cudaHostRegister((void*)ptr, size, flags));
  });
  cudart.def(
@@ -80,6 +81,7 @@
  "HostUnregister",
  [](uintptr_t ptr) -> cudaError_t {
  py::gil_scoped_release no_gil;
+ // NOLINTNEXTLINE(performance-no-int-to-ptr)
  return C10_CUDA_ERROR_HANDLED(cudaHostUnregister((void*)ptr));
  });
  cudart.def(
@@ -87,6 +89,7 @@
  "StreamCreate",
  [](uintptr_t ptr) -> cudaError_t {
  py::gil_scoped_release no_gil;
+ // NOLINTNEXTLINE(performance-no-int-to-ptr)
  return C10_CUDA_ERROR_HANDLED(cudaStreamCreate((cudaStream_t*)ptr));
  });
  cudart.def(
@@ -94,6 +97,7 @@
  "StreamDestroy",
  [](uintptr_t ptr) -> cudaError_t {
  py::gil_scoped_release no_gil;
+ // NOLINTNEXTLINE(performance-no-int-to-ptr)
  return C10_CUDA_ERROR_HANDLED(cudaStreamDestroy((cudaStream_t)ptr));
  });
  #if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION < 12000