[BE] use DeviceIndex instead of int64_t for related device interfaces (#103068)

This PR unifies the device interfaces in aten/*cpp and torch/csrc/*cpp to use **c10::DeviceIndex** instead of int64_t.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103068
Approved by: https://github.com/malfet
This commit is contained in:
cyy
2023-08-25 20:16:11 +00:00
committed by PyTorch MergeBot
parent 4656e09431
commit d9fb7166d6
18 changed files with 73 additions and 56 deletions

View File

@ -82,6 +82,7 @@ namespace impl {
c10::QScheme,
c10::ScalarType,
c10::Device,
c10::DeviceIndex,
c10::Layout,
c10::MemoryFormat,
at::Dimname

View File

@ -11,6 +11,7 @@
#include <c10/util/Optional.h>
#include <c10/core/SymFloat.h>
#include <c10/core/SymBool.h>
#include <c10/core/Device.h>
#include <array>
#include <memory>
@ -1869,6 +1870,13 @@ struct getTypePtr_<int64_t> final {
}
};
// Specialization of getTypePtr_ for DeviceIndex: call() returns IntType::get(),
// i.e. a DeviceIndex is described by the IntType singleton.
template <>
struct getTypePtr_<DeviceIndex> final {
// Returns whatever IntType::get() yields; decltype(auto) keeps that exact
// return type instead of decaying it.
static decltype(auto) call() {
return IntType::get();
}
};
template <>
struct getMaybeFakeTypePtr_<SymInt, false> final {
static decltype(auto) call() {

View File

@ -38,8 +38,8 @@ constexpr int checkStaticTypes() {
// Give nice error messages for some of the common error cases.
// Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT
static_assert(guts::conjunction<
bool_t<!std::is_integral<Types>::value || std::is_same<Types, int64_t>::value || std::is_same<Types, bool>::value>...
>::value, "INVALID TYPE: Only int64_t and bool are supported as an integral argument type");
bool_t<!std::is_integral<Types>::value || std::is_same<Types, int8_t>::value || std::is_same<Types, int64_t>::value || std::is_same<Types, bool>::value>...
>::value, "INVALID TYPE: Only int8_t, int64_t and bool are supported as an integral argument type");
static_assert(guts::conjunction<
bool_t<!std::is_same<Types, float>::value>...
>::value, "INVALID TYPE: float is not supported as an argument type, use double instead");

View File

@ -42,13 +42,13 @@
#include <memory>
namespace c10::cuda::_internal {
void setHasPrimaryContext(bool (*func)(int64_t));
void setHasPrimaryContext(bool (*func)(DeviceIndex));
}
namespace at::cuda::detail {
const at::cuda::NVRTC& nvrtc();
int64_t current_device();
DeviceIndex current_device();
static void (*magma_init_fn)() = nullptr;
@ -57,7 +57,7 @@ void set_magma_init_fn(void (*fn)()) {
}
namespace {
bool _hasPrimaryContext(int64_t device_index) {
bool _hasPrimaryContext(DeviceIndex device_index) {
TORCH_CHECK(device_index >= 0 && device_index < at::cuda::device_count(),
"hasPrimaryContext expects a valid device index, but got device_index=", device_index);
unsigned int ctx_flags;
@ -226,7 +226,7 @@ const at::cuda::NVRTC& CUDAHooks::nvrtc() const {
return at::cuda::detail::nvrtc();
}
int64_t current_device() {
DeviceIndex current_device() {
int device;
cudaError_t err = c10::cuda::GetDevice(&device);
if (err == cudaSuccess) {
@ -235,11 +235,11 @@ int64_t current_device() {
return -1;
}
int64_t CUDAHooks::current_device() const {
DeviceIndex CUDAHooks::current_device() const {
return at::cuda::detail::current_device();
}
bool CUDAHooks::hasPrimaryContext(int64_t device_index) const {
bool CUDAHooks::hasPrimaryContext(DeviceIndex device_index) const {
return _hasPrimaryContext(device_index);
}
@ -414,19 +414,19 @@ double CUDAHooks::batchnormMinEpsilonCuDNN() const {
#endif
}
int64_t CUDAHooks::cuFFTGetPlanCacheMaxSize(int64_t device_index) const {
int64_t CUDAHooks::cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const {
return at::native::detail::cufft_get_plan_cache_max_size_impl(device_index);
}
void CUDAHooks::cuFFTSetPlanCacheMaxSize(int64_t device_index, int64_t max_size) const {
void CUDAHooks::cuFFTSetPlanCacheMaxSize(DeviceIndex device_index, int64_t max_size) const {
at::native::detail::cufft_set_plan_cache_max_size_impl(device_index, max_size);
}
int64_t CUDAHooks::cuFFTGetPlanCacheSize(int64_t device_index) const {
int64_t CUDAHooks::cuFFTGetPlanCacheSize(DeviceIndex device_index) const {
return at::native::detail::cufft_get_plan_cache_size_impl(device_index);
}
void CUDAHooks::cuFFTClearPlanCache(int64_t device_index) const {
void CUDAHooks::cuFFTClearPlanCache(DeviceIndex device_index) const {
at::native::detail::cufft_clear_plan_cache_impl(device_index);
}
@ -434,7 +434,7 @@ int CUDAHooks::getNumGPUs() const {
return at::cuda::device_count();
}
void CUDAHooks::deviceSynchronize(int64_t device_index) const {
void CUDAHooks::deviceSynchronize(DeviceIndex device_index) const {
at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
c10::cuda::device_synchronize();
}

View File

@ -29,8 +29,8 @@ struct CUDAHooks : public at::CUDAHooksInterface {
bool hasCuSOLVER() const override;
bool hasROCM() const override;
const at::cuda::NVRTC& nvrtc() const override;
int64_t current_device() const override;
bool hasPrimaryContext(int64_t device_index) const override;
DeviceIndex current_device() const override;
bool hasPrimaryContext(DeviceIndex device_index) const override;
Allocator* getCUDADeviceAllocator() const override;
Allocator* getPinnedMemoryAllocator() const override;
bool compiledWithCuDNN() const override;
@ -43,12 +43,12 @@ struct CUDAHooks : public at::CUDAHooksInterface {
long versionCuDNN() const override;
std::string showConfig() const override;
double batchnormMinEpsilonCuDNN() const override;
int64_t cuFFTGetPlanCacheMaxSize(int64_t device_index) const override;
void cuFFTSetPlanCacheMaxSize(int64_t device_index, int64_t max_size) const override;
int64_t cuFFTGetPlanCacheSize(int64_t device_index) const override;
void cuFFTClearPlanCache(int64_t device_index) const override;
int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const override;
void cuFFTSetPlanCacheMaxSize(DeviceIndex device_index, int64_t max_size) const override;
int64_t cuFFTGetPlanCacheSize(DeviceIndex device_index) const override;
void cuFFTClearPlanCache(DeviceIndex device_index) const override;
int getNumGPUs() const override;
void deviceSynchronize(int64_t device_index) const override;
void deviceSynchronize(DeviceIndex device_index) const override;
};
}}} // at::cuda::detail

View File

@ -114,11 +114,11 @@ struct TORCH_API CUDAHooksInterface {
TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP);
}
virtual bool hasPrimaryContext(int64_t device_index) const {
// Fallback used when no CUDA implementation overrides this hook: it never
// returns and always fails the TORCH_CHECK with an explanatory message.
//
// device_index: the device index the caller asked about; only used in the
//               error message.
//
// NOTE: DeviceIndex is presumably a char-sized integer (int8_t) — confirm
// against c10/core/Device.h. Streaming such a value directly would format it
// as a character, so it is widened to int first, matching how Device prints
// its index in its own assertions.
virtual bool hasPrimaryContext(DeviceIndex device_index) const {
  TORCH_CHECK(false, "Cannot call hasPrimaryContext(", static_cast<int>(device_index), ") without ATen_cuda library. ", CUDA_HELP);
}
virtual int64_t current_device() const {
virtual DeviceIndex current_device() const {
return -1;
}
@ -167,19 +167,19 @@ struct TORCH_API CUDAHooksInterface {
"Cannot query batchnormMinEpsilonCuDNN() without ATen_cuda library. ", CUDA_HELP);
}
virtual int64_t cuFFTGetPlanCacheMaxSize(int64_t /*device_index*/) const {
virtual int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex /*device_index*/) const {
TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
}
virtual void cuFFTSetPlanCacheMaxSize(int64_t /*device_index*/, int64_t /*max_size*/) const {
virtual void cuFFTSetPlanCacheMaxSize(DeviceIndex /*device_index*/, int64_t /*max_size*/) const {
TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
}
virtual int64_t cuFFTGetPlanCacheSize(int64_t /*device_index*/) const {
virtual int64_t cuFFTGetPlanCacheSize(DeviceIndex /*device_index*/) const {
TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
}
virtual void cuFFTClearPlanCache(int64_t /*device_index*/) const {
virtual void cuFFTClearPlanCache(DeviceIndex /*device_index*/) const {
TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
}
@ -187,7 +187,7 @@ struct TORCH_API CUDAHooksInterface {
return 0;
}
virtual void deviceSynchronize(int64_t /*device_index*/) const {
virtual void deviceSynchronize(DeviceIndex /*device_index*/) const {
TORCH_CHECK(false, "Cannot synchronize CUDA device without ATen_cuda library. ", CUDA_HELP);
}
};

View File

@ -793,19 +793,19 @@ Tensor fft_ifftshift(const Tensor& x, at::OptionalIntArrayRef dim_opt) {
// We call the following methods via CUDA hooks because they are really only
// valid when CUDA is available. See native/cuda/CuFFTPlanCache.h for more details.
int64_t _cufft_get_plan_cache_max_size(int64_t device_index) {
int64_t _cufft_get_plan_cache_max_size(DeviceIndex device_index) {
return detail::getCUDAHooks().cuFFTGetPlanCacheMaxSize(device_index);
}
void _cufft_set_plan_cache_max_size(int64_t device_index, int64_t max_size) {
void _cufft_set_plan_cache_max_size(DeviceIndex device_index, int64_t max_size) {
detail::getCUDAHooks().cuFFTSetPlanCacheMaxSize(device_index, max_size);
}
int64_t _cufft_get_plan_cache_size(int64_t device_index) {
int64_t _cufft_get_plan_cache_size(DeviceIndex device_index) {
return detail::getCUDAHooks().cuFFTGetPlanCacheSize(device_index);
}
void _cufft_clear_plan_cache(int64_t device_index) {
void _cufft_clear_plan_cache(DeviceIndex device_index) {
detail::getCUDAHooks().cuFFTClearPlanCache(device_index);
}

View File

@ -524,9 +524,9 @@ private:
// native function counterparts (at native/SpectralOps.cpp), i.e.,
// _cufft_get_plan_cache_max_size, _cufft_set_plan_cache_max_size
// _cufft_get_plan_cache_size, and _cufft_clear_plan_cache.
int64_t cufft_get_plan_cache_max_size_impl(int64_t device_index);
void cufft_set_plan_cache_max_size_impl(int64_t device_index, int64_t max_size);
int64_t cufft_get_plan_cache_size_impl(int64_t device_index);
void cufft_clear_plan_cache_impl(int64_t device_index);
int64_t cufft_get_plan_cache_max_size_impl(DeviceIndex device_index);
void cufft_set_plan_cache_max_size_impl(DeviceIndex device_index, int64_t max_size);
int64_t cufft_get_plan_cache_size_impl(DeviceIndex device_index);
void cufft_clear_plan_cache_impl(DeviceIndex device_index);
}}} // namespace at::native::detail

View File

@ -133,7 +133,7 @@ static std::vector<std::unique_ptr<CuFFTParamsLRUCache>> plan_caches;
static std::mutex plan_caches_mutex;
static inline
CuFFTParamsLRUCache &cufft_get_plan_cache(int64_t device_index) {
CuFFTParamsLRUCache &cufft_get_plan_cache(DeviceIndex device_index) {
std::lock_guard<std::mutex> guard(plan_caches_mutex);
AT_ASSERT(device_index >= 0);
@ -152,7 +152,7 @@ CuFFTParamsLRUCache &cufft_get_plan_cache(int64_t device_index) {
namespace detail {
int64_t cufft_get_plan_cache_max_size_impl(int64_t device_index) {
int64_t cufft_get_plan_cache_max_size_impl(DeviceIndex device_index) {
TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
"cufft_get_plan_cache_max_size: expected 0 <= device_index < ",
at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=",
@ -160,7 +160,7 @@ int64_t cufft_get_plan_cache_max_size_impl(int64_t device_index) {
return cufft_get_plan_cache(device_index).max_size();
}
void cufft_set_plan_cache_max_size_impl(int64_t device_index, int64_t max_size) {
void cufft_set_plan_cache_max_size_impl(DeviceIndex device_index, int64_t max_size) {
TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
"cufft_set_plan_cache_max_size: expected 0 <= device_index < ",
at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=",
@ -168,7 +168,7 @@ void cufft_set_plan_cache_max_size_impl(int64_t device_index, int64_t max_size)
return cufft_get_plan_cache(device_index).resize(max_size);
}
int64_t cufft_get_plan_cache_size_impl(int64_t device_index) {
int64_t cufft_get_plan_cache_size_impl(DeviceIndex device_index) {
TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
"cufft_get_plan_cache_size: expected 0 <= device_index < ",
at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=",
@ -176,7 +176,7 @@ int64_t cufft_get_plan_cache_size_impl(int64_t device_index) {
return cufft_get_plan_cache(device_index).size();
}
void cufft_clear_plan_cache_impl(int64_t device_index) {
void cufft_clear_plan_cache_impl(DeviceIndex device_index) {
TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
"cufft_clear_plan_cache: expected 0 <= device_index < ",
at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=",

View File

@ -2945,13 +2945,13 @@
CPU: _validate_compressed_sparse_indices_cpu
CUDA: _validate_compressed_sparse_indices_cuda
- func: _cufft_get_plan_cache_size(int device_index) -> int
- func: _cufft_get_plan_cache_size(DeviceIndex device_index) -> int
- func: _cufft_get_plan_cache_max_size(int device_index) -> int
- func: _cufft_get_plan_cache_max_size(DeviceIndex device_index) -> int
- func: _cufft_set_plan_cache_max_size(int device_index, int max_size) -> ()
- func: _cufft_set_plan_cache_max_size(DeviceIndex device_index, int max_size) -> ()
- func: _cufft_clear_plan_cache(int device_index) -> ()
- func: _cufft_clear_plan_cache(DeviceIndex device_index) -> ()
- func: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
device_check: NoCheck # TensorIterator

View File

@ -174,13 +174,13 @@ struct C10_API Device final {
// This is safe to do, because backends that use the DeviceIndex
// have a later check when we actually try to switch to that device.
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
index_ == -1 || index_ >= 0,
index_ >= -1,
"Device index must be -1 or non-negative, got ",
(int)index_);
static_cast<int>(index_));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!is_cpu() || index_ <= 0,
"CPU device index must be -1 or zero, got ",
(int)index_);
static_cast<int>(index_));
}
};

View File

@ -148,9 +148,9 @@ void warn_or_error_on_sync() {
}
}
c10::optional<int64_t> getDeviceIndexWithPrimaryContext() {
c10::optional<DeviceIndex> getDeviceIndexWithPrimaryContext() {
// check current device first
int64_t current_device_index = current_device();
auto current_device_index = current_device();
if (current_device_index >= 0) {
if (hasPrimaryContext(current_device_index)) {
return current_device_index;
@ -167,18 +167,18 @@ c10::optional<int64_t> getDeviceIndexWithPrimaryContext() {
}
namespace _internal {
bool dummyHasPrimaryContext(C10_UNUSED int64_t device_index) {
bool dummyHasPrimaryContext(C10_UNUSED DeviceIndex device_index) {
TORCH_CHECK(false, "Should never been called");
}
bool (*hasPrimaryContext)(int64_t) = dummyHasPrimaryContext;
bool (*hasPrimaryContext)(DeviceIndex) = dummyHasPrimaryContext;
// Private api to be called from CUDAHooks.cpp
C10_CUDA_API void setHasPrimaryContext(bool (*func)(int64_t)) {
// Installs the callback used by hasPrimaryContext(). Passing a null pointer
// reinstates dummyHasPrimaryContext as the active callback.
C10_CUDA_API void setHasPrimaryContext(bool (*func)(DeviceIndex)) {
  if (!func) {
    hasPrimaryContext = dummyHasPrimaryContext;
    return;
  }
  hasPrimaryContext = func;
}
} // namespace _internal
bool hasPrimaryContext(int64_t device_index) {
bool hasPrimaryContext(DeviceIndex device_index) {
return _internal::hasPrimaryContext(device_index);
}

View File

@ -111,8 +111,8 @@ C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) {
C10_CUDA_CHECK(cudaStreamSynchronize(stream));
}
C10_CUDA_API bool hasPrimaryContext(int64_t device_index);
C10_CUDA_API c10::optional<int64_t> getDeviceIndexWithPrimaryContext();
C10_CUDA_API bool hasPrimaryContext(DeviceIndex device_index);
C10_CUDA_API c10::optional<DeviceIndex> getDeviceIndexWithPrimaryContext();
} // namespace cuda
} // namespace c10

View File

@ -58,6 +58,7 @@ TypePtr SchemaTypeParser::parseBaseType() {
// use the custom class mechanism
// instead. @jerryzh
{"Device", c10::TypeFactory::get<DeviceObjType>()},
{"DeviceIndex", c10::TypeFactory::get<IntType>()},
{"Stream", c10::TypeFactory::get<StreamObjType>()},
{"Scalar", c10::TypeFactory::get<NumberType>()},
{"str", c10::TypeFactory::get<StringType>()},

View File

@ -43,6 +43,7 @@ static std::unordered_map<std::string, ParameterType> type_map = {
{"MemoryFormat", ParameterType::MEMORY_FORMAT},
{"QScheme", ParameterType::QSCHEME},
{"Device", ParameterType::DEVICE},
{"DeviceIndex", ParameterType::INT64},
{"Stream", ParameterType::STREAM},
{"std::string", ParameterType::STRING},
{"c10::string_view", ParameterType::STRING},

View File

@ -661,6 +661,7 @@ def argument_type_str(
BaseTy.Storage,
BaseTy.Layout,
BaseTy.Device,
BaseTy.DeviceIndex,
BaseTy.MemoryFormat,
BaseTy.Dimname,
BaseTy.Stream,
@ -907,7 +908,7 @@ def argument_type_str_pyi(t: Type) -> str:
add_optional = True
if isinstance(t, BaseType):
if t.name == BaseTy.int:
if t.name in [BaseTy.int, BaseTy.DeviceIndex]:
ret = "_int"
if t.name == BaseTy.SymInt:
ret = "Union[_int, SymInt]"
@ -1255,6 +1256,8 @@ def arg_parser_unpack_method(
return "scalartypeWithDefault" if has_default_init else "scalartype"
elif t.name == BaseTy.Device:
return "deviceWithDefault" if has_default_init else "device"
elif t.name == BaseTy.DeviceIndex:
return "toInt64"
elif t.name == BaseTy.int:
return "toInt64"
elif t.name == BaseTy.SymInt:

View File

@ -62,6 +62,7 @@ dimnameListT = BaseCppType("at", "DimnameList")
dimVectorT = BaseCppType("at", "DimVector")
layoutT = BaseCppType("at", "Layout")
deviceT = BaseCppType("at", "Device")
deviceIndexT = BaseCppType("at", "DeviceIndex")
scalarT = BaseCppType("at", "Scalar")
optionalScalarRefT = BaseCppType("at", "OptionalScalarRef")
memoryFormatT = BaseCppType("at", "MemoryFormat")
@ -111,6 +112,7 @@ BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {
BaseTy.DimVector: dimVectorT,
BaseTy.Layout: layoutT,
BaseTy.Device: deviceT,
BaseTy.DeviceIndex: deviceIndexT,
BaseTy.Scalar: scalarT,
BaseTy.MemoryFormat: memoryFormatT,
BaseTy.QScheme: qschemeT,

View File

@ -1803,6 +1803,7 @@ class BaseTy(Enum):
bool = auto()
Layout = auto()
Device = auto()
DeviceIndex = auto()
Scalar = auto()
MemoryFormat = auto()
QScheme = auto()