[BE] use DeviceIndex instead of int64_t for related device interfaces (#103068)

This PR unifies the device interfaces in aten/*.cpp and torch/csrc/*.cpp to use **c10::DeviceIndex**.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103068
Approved by: https://github.com/malfet
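For context: `c10::DeviceIndex` is a narrow `int8_t` typedef in c10/core/Device.h, so every signature in this diff trades a 64-bit parameter for an 8-bit one. A minimal standalone sketch of the shape of the change, using `hasPrimaryContext` as the representative interface (the stub body is illustrative, not PyTorch source):

```cpp
// Minimal standalone sketch: c10::DeviceIndex is an int8_t typedef, so the
// device-index parameter narrows from 64 bits to 8 bits at every interface.
#include <cstdint>
#include <iostream>

namespace c10 {
using DeviceIndex = int8_t;  // mirrors the typedef in c10/core/Device.h
}

// Before this commit: bool hasPrimaryContext(int64_t device_index);
// After this commit:
bool hasPrimaryContext(c10::DeviceIndex device_index) {
  return device_index >= 0;  // stub body for illustration only
}

int main() {
  std::cout << hasPrimaryContext(0) << "\n";  // prints 1
}
```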
@@ -82,6 +82,7 @@ namespace impl {
     c10::QScheme,
     c10::ScalarType,
     c10::Device,
+    c10::DeviceIndex,
     c10::Layout,
     c10::MemoryFormat,
     at::Dimname

@@ -11,6 +11,7 @@
 #include <c10/util/Optional.h>
 #include <c10/core/SymFloat.h>
 #include <c10/core/SymBool.h>
+#include <c10/core/Device.h>

 #include <array>
 #include <memory>

@@ -1869,6 +1870,13 @@ struct getTypePtr_<int64_t> final {
   }
 };

+template <>
+struct getTypePtr_<DeviceIndex> final {
+  static decltype(auto) call() {
+    return IntType::get();
+  }
+};
+
 template <>
 struct getMaybeFakeTypePtr_<SymInt, false> final {
   static decltype(auto) call() {

@@ -38,8 +38,8 @@ constexpr int checkStaticTypes() {
   // Give nice error messages for some of the common error cases.
   // Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT
   static_assert(guts::conjunction<
-    bool_t<!std::is_integral<Types>::value || std::is_same<Types, int64_t>::value || std::is_same<Types, bool>::value>...
-  >::value, "INVALID TYPE: Only int64_t and bool are supported as an integral argument type");
+    bool_t<!std::is_integral<Types>::value || std::is_same<Types, int8_t>::value || std::is_same<Types, int64_t>::value || std::is_same<Types, bool>::value>...
+  >::value, "INVALID TYPE: Only int8_t, int64_t and bool are supported as an integral argument type");
   static_assert(guts::conjunction<
     bool_t<!std::is_same<Types, float>::value>...
   >::value, "INVALID TYPE: float is not supported as an argument type, use double instead");
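This hunk is needed because `c10::DeviceIndex` is an `int8_t`: once operator signatures carry a `DeviceIndex` argument, `int8_t` must join the integral allowlist or schema inference trips this static_assert. A compilable sketch of the same check, with `std::conjunction` and `std::bool_constant` standing in for PyTorch's `guts::conjunction` and `bool_t`:

```cpp
// Standalone reconstruction of the updated integral-type check.
#include <cstdint>
#include <type_traits>

template <class... Types>
constexpr int checkStaticTypes() {
  static_assert(
      std::conjunction<std::bool_constant<
          !std::is_integral<Types>::value ||
          std::is_same<Types, int8_t>::value ||
          std::is_same<Types, int64_t>::value ||
          std::is_same<Types, bool>::value>...>::value,
      "INVALID TYPE: Only int8_t, int64_t and bool are supported as an integral argument type");
  return 0;
}

using DeviceIndex = int8_t;  // mirrors c10::DeviceIndex

int main() {
  checkStaticTypes<int64_t, bool, DeviceIndex>();  // compiles after this commit
  // checkStaticTypes<int32_t>();  // would still fail the static_assert
}
```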
@@ -42,13 +42,13 @@
 #include <memory>

 namespace c10::cuda::_internal {
-void setHasPrimaryContext(bool (*func)(int64_t));
+void setHasPrimaryContext(bool (*func)(DeviceIndex));
 }

 namespace at::cuda::detail {

 const at::cuda::NVRTC& nvrtc();
-int64_t current_device();
+DeviceIndex current_device();

 static void (*magma_init_fn)() = nullptr;

@@ -57,7 +57,7 @@ void set_magma_init_fn(void (*fn)()) {
 }

 namespace {
-bool _hasPrimaryContext(int64_t device_index) {
+bool _hasPrimaryContext(DeviceIndex device_index) {
   TORCH_CHECK(device_index >= 0 && device_index < at::cuda::device_count(),
               "hasPrimaryContext expects a valid device index, but got device_index=", device_index);
   unsigned int ctx_flags;

@@ -226,7 +226,7 @@ const at::cuda::NVRTC& CUDAHooks::nvrtc() const {
   return at::cuda::detail::nvrtc();
 }

-int64_t current_device() {
+DeviceIndex current_device() {
   int device;
   cudaError_t err = c10::cuda::GetDevice(&device);
   if (err == cudaSuccess) {

@@ -235,11 +235,11 @@ int64_t current_device() {
   return -1;
 }

-int64_t CUDAHooks::current_device() const {
+DeviceIndex CUDAHooks::current_device() const {
   return at::cuda::detail::current_device();
 }

-bool CUDAHooks::hasPrimaryContext(int64_t device_index) const {
+bool CUDAHooks::hasPrimaryContext(DeviceIndex device_index) const {
   return _hasPrimaryContext(device_index);
 }

@@ -414,19 +414,19 @@ double CUDAHooks::batchnormMinEpsilonCuDNN() const {
 #endif
 }

-int64_t CUDAHooks::cuFFTGetPlanCacheMaxSize(int64_t device_index) const {
+int64_t CUDAHooks::cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const {
   return at::native::detail::cufft_get_plan_cache_max_size_impl(device_index);
 }

-void CUDAHooks::cuFFTSetPlanCacheMaxSize(int64_t device_index, int64_t max_size) const {
+void CUDAHooks::cuFFTSetPlanCacheMaxSize(DeviceIndex device_index, int64_t max_size) const {
   at::native::detail::cufft_set_plan_cache_max_size_impl(device_index, max_size);
 }

-int64_t CUDAHooks::cuFFTGetPlanCacheSize(int64_t device_index) const {
+int64_t CUDAHooks::cuFFTGetPlanCacheSize(DeviceIndex device_index) const {
   return at::native::detail::cufft_get_plan_cache_size_impl(device_index);
 }

-void CUDAHooks::cuFFTClearPlanCache(int64_t device_index) const {
+void CUDAHooks::cuFFTClearPlanCache(DeviceIndex device_index) const {
   at::native::detail::cufft_clear_plan_cache_impl(device_index);
 }

@@ -434,7 +434,7 @@ int CUDAHooks::getNumGPUs() const {
   return at::cuda::device_count();
 }

-void CUDAHooks::deviceSynchronize(int64_t device_index) const {
+void CUDAHooks::deviceSynchronize(DeviceIndex device_index) const {
   at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
   c10::cuda::device_synchronize();
 }
@@ -29,8 +29,8 @@ struct CUDAHooks : public at::CUDAHooksInterface {
   bool hasCuSOLVER() const override;
   bool hasROCM() const override;
   const at::cuda::NVRTC& nvrtc() const override;
-  int64_t current_device() const override;
-  bool hasPrimaryContext(int64_t device_index) const override;
+  DeviceIndex current_device() const override;
+  bool hasPrimaryContext(DeviceIndex device_index) const override;
   Allocator* getCUDADeviceAllocator() const override;
   Allocator* getPinnedMemoryAllocator() const override;
   bool compiledWithCuDNN() const override;

@@ -43,12 +43,12 @@ struct CUDAHooks : public at::CUDAHooksInterface {
   long versionCuDNN() const override;
   std::string showConfig() const override;
   double batchnormMinEpsilonCuDNN() const override;
-  int64_t cuFFTGetPlanCacheMaxSize(int64_t device_index) const override;
-  void cuFFTSetPlanCacheMaxSize(int64_t device_index, int64_t max_size) const override;
-  int64_t cuFFTGetPlanCacheSize(int64_t device_index) const override;
-  void cuFFTClearPlanCache(int64_t device_index) const override;
+  int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const override;
+  void cuFFTSetPlanCacheMaxSize(DeviceIndex device_index, int64_t max_size) const override;
+  int64_t cuFFTGetPlanCacheSize(DeviceIndex device_index) const override;
+  void cuFFTClearPlanCache(DeviceIndex device_index) const override;
   int getNumGPUs() const override;
-  void deviceSynchronize(int64_t device_index) const override;
+  void deviceSynchronize(DeviceIndex device_index) const override;
 };

 }}} // at::cuda::detail
@@ -114,11 +114,11 @@ struct TORCH_API CUDAHooksInterface {
     TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP);
   }

-  virtual bool hasPrimaryContext(int64_t device_index) const {
+  virtual bool hasPrimaryContext(DeviceIndex device_index) const {
     TORCH_CHECK(false, "Cannot call hasPrimaryContext(", device_index, ") without ATen_cuda library. ", CUDA_HELP);
   }

-  virtual int64_t current_device() const {
+  virtual DeviceIndex current_device() const {
     return -1;
   }

@@ -167,19 +167,19 @@ struct TORCH_API CUDAHooksInterface {
         "Cannot query batchnormMinEpsilonCuDNN() without ATen_cuda library. ", CUDA_HELP);
   }

-  virtual int64_t cuFFTGetPlanCacheMaxSize(int64_t /*device_index*/) const {
+  virtual int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex /*device_index*/) const {
     TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
   }

-  virtual void cuFFTSetPlanCacheMaxSize(int64_t /*device_index*/, int64_t /*max_size*/) const {
+  virtual void cuFFTSetPlanCacheMaxSize(DeviceIndex /*device_index*/, int64_t /*max_size*/) const {
     TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
   }

-  virtual int64_t cuFFTGetPlanCacheSize(int64_t /*device_index*/) const {
+  virtual int64_t cuFFTGetPlanCacheSize(DeviceIndex /*device_index*/) const {
     TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
   }

-  virtual void cuFFTClearPlanCache(int64_t /*device_index*/) const {
+  virtual void cuFFTClearPlanCache(DeviceIndex /*device_index*/) const {
     TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
   }

@@ -187,7 +187,7 @@ struct TORCH_API CUDAHooksInterface {
     return 0;
   }

-  virtual void deviceSynchronize(int64_t /*device_index*/) const {
+  virtual void deviceSynchronize(DeviceIndex /*device_index*/) const {
     TORCH_CHECK(false, "Cannot synchronize CUDA device without ATen_cuda library. ", CUDA_HELP);
   }
 };
@@ -793,19 +793,19 @@ Tensor fft_ifftshift(const Tensor& x, at::OptionalIntArrayRef dim_opt) {

 // We call the following methods via CUDA hooks because they are really only
 // valid when CUDA is available. See native/cuda/CuFFTPlanCache.h for more details.
-int64_t _cufft_get_plan_cache_max_size(int64_t device_index) {
+int64_t _cufft_get_plan_cache_max_size(DeviceIndex device_index) {
   return detail::getCUDAHooks().cuFFTGetPlanCacheMaxSize(device_index);
 }

-void _cufft_set_plan_cache_max_size(int64_t device_index, int64_t max_size) {
+void _cufft_set_plan_cache_max_size(DeviceIndex device_index, int64_t max_size) {
   detail::getCUDAHooks().cuFFTSetPlanCacheMaxSize(device_index, max_size);
 }

-int64_t _cufft_get_plan_cache_size(int64_t device_index) {
+int64_t _cufft_get_plan_cache_size(DeviceIndex device_index) {
   return detail::getCUDAHooks().cuFFTGetPlanCacheSize(device_index);
 }

-void _cufft_clear_plan_cache(int64_t device_index) {
+void _cufft_clear_plan_cache(DeviceIndex device_index) {
   detail::getCUDAHooks().cuFFTClearPlanCache(device_index);
 }
@@ -524,9 +524,9 @@ private:
 // native function counterparts (at native/SpectralOps.cpp), i.e.,
 // _cufft_get_plan_cache_max_size, _cufft_set_plan_cache_max_size
 // _cufft_get_plan_cache_size, and _cufft_clear_plan_cache.
-int64_t cufft_get_plan_cache_max_size_impl(int64_t device_index);
-void cufft_set_plan_cache_max_size_impl(int64_t device_index, int64_t max_size);
-int64_t cufft_get_plan_cache_size_impl(int64_t device_index);
-void cufft_clear_plan_cache_impl(int64_t device_index);
+int64_t cufft_get_plan_cache_max_size_impl(DeviceIndex device_index);
+void cufft_set_plan_cache_max_size_impl(DeviceIndex device_index, int64_t max_size);
+int64_t cufft_get_plan_cache_size_impl(DeviceIndex device_index);
+void cufft_clear_plan_cache_impl(DeviceIndex device_index);

 }}} // namespace at::native::detail
@@ -133,7 +133,7 @@ static std::vector<std::unique_ptr<CuFFTParamsLRUCache>> plan_caches;
 static std::mutex plan_caches_mutex;

 static inline
-CuFFTParamsLRUCache &cufft_get_plan_cache(int64_t device_index) {
+CuFFTParamsLRUCache &cufft_get_plan_cache(DeviceIndex device_index) {
   std::lock_guard<std::mutex> guard(plan_caches_mutex);

   AT_ASSERT(device_index >= 0);

@@ -152,7 +152,7 @@ CuFFTParamsLRUCache &cufft_get_plan_cache(int64_t device_index) {

 namespace detail {

-int64_t cufft_get_plan_cache_max_size_impl(int64_t device_index) {
+int64_t cufft_get_plan_cache_max_size_impl(DeviceIndex device_index) {
   TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
     "cufft_get_plan_cache_max_size: expected 0 <= device_index < ",
     at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=",

@@ -160,7 +160,7 @@ int64_t cufft_get_plan_cache_max_size_impl(int64_t device_index) {
   return cufft_get_plan_cache(device_index).max_size();
 }

-void cufft_set_plan_cache_max_size_impl(int64_t device_index, int64_t max_size) {
+void cufft_set_plan_cache_max_size_impl(DeviceIndex device_index, int64_t max_size) {
   TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
     "cufft_set_plan_cache_max_size: expected 0 <= device_index < ",
     at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=",

@@ -168,7 +168,7 @@ void cufft_set_plan_cache_max_size_impl(int64_t device_index, int64_t max_size)
   return cufft_get_plan_cache(device_index).resize(max_size);
 }

-int64_t cufft_get_plan_cache_size_impl(int64_t device_index) {
+int64_t cufft_get_plan_cache_size_impl(DeviceIndex device_index) {
   TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
     "cufft_get_plan_cache_size: expected 0 <= device_index < ",
     at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=",

@@ -176,7 +176,7 @@ int64_t cufft_get_plan_cache_size_impl(int64_t device_index) {
   return cufft_get_plan_cache(device_index).size();
 }

-void cufft_clear_plan_cache_impl(int64_t device_index) {
+void cufft_clear_plan_cache_impl(DeviceIndex device_index) {
   TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
     "cufft_clear_plan_cache: expected 0 <= device_index < ",
     at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=",
@@ -2945,13 +2945,13 @@
   CPU: _validate_compressed_sparse_indices_cpu
   CUDA: _validate_compressed_sparse_indices_cuda

-- func: _cufft_get_plan_cache_size(int device_index) -> int
+- func: _cufft_get_plan_cache_size(DeviceIndex device_index) -> int

-- func: _cufft_get_plan_cache_max_size(int device_index) -> int
+- func: _cufft_get_plan_cache_max_size(DeviceIndex device_index) -> int

-- func: _cufft_set_plan_cache_max_size(int device_index, int max_size) -> ()
+- func: _cufft_set_plan_cache_max_size(DeviceIndex device_index, int max_size) -> ()

-- func: _cufft_clear_plan_cache(int device_index) -> ()
+- func: _cufft_clear_plan_cache(DeviceIndex device_index) -> ()

 - func: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
   device_check: NoCheck # TensorIterator
@@ -174,13 +174,13 @@ struct C10_API Device final {
     // This is safe to do, because backends that use the DeviceIndex
     // have a later check when we actually try to switch to that device.
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-        index_ == -1 || index_ >= 0,
+        index_ >= -1,
         "Device index must be -1 or non-negative, got ",
-        (int)index_);
+        static_cast<int>(index_));
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
         !is_cpu() || index_ <= 0,
         "CPU device index must be -1 or zero, got ",
-        (int)index_);
+        static_cast<int>(index_));
   }
 };
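Two details in the Device.h hunk above are easy to miss: for a signed index, `index_ >= -1` is exactly equivalent to `index_ == -1 || index_ >= 0`, and the cast to `int` is load-bearing because streaming an `int8_t` prints a character, not a number. A small standalone demonstration:

```cpp
// Standalone demonstration (not PyTorch source) of the two subtleties above.
#include <cstdint>
#include <iostream>

int main() {
  int8_t index = 65;
  std::cout << index << "\n";                    // prints "A": int8_t streams as a char
  std::cout << static_cast<int>(index) << "\n";  // prints "65": the cast fixes that

  // The simplified predicate is equivalent for every signed value:
  bool old_check = (index == -1 || index >= 0);
  bool new_check = (index >= -1);
  std::cout << (old_check == new_check) << "\n";  // always prints "1"
}
```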
@@ -148,9 +148,9 @@ void warn_or_error_on_sync() {
   }
 }

-c10::optional<int64_t> getDeviceIndexWithPrimaryContext() {
+c10::optional<DeviceIndex> getDeviceIndexWithPrimaryContext() {
   // check current device first
-  int64_t current_device_index = current_device();
+  auto current_device_index = current_device();
   if (current_device_index >= 0) {
     if (hasPrimaryContext(current_device_index)) {
       return current_device_index;

@@ -167,18 +167,18 @@ c10::optional<int64_t> getDeviceIndexWithPrimaryContext() {
 }

 namespace _internal {
-bool dummyHasPrimaryContext(C10_UNUSED int64_t device_index) {
+bool dummyHasPrimaryContext(C10_UNUSED DeviceIndex device_index) {
   TORCH_CHECK(false, "Should never been called");
 }
-bool (*hasPrimaryContext)(int64_t) = dummyHasPrimaryContext;
+bool (*hasPrimaryContext)(DeviceIndex) = dummyHasPrimaryContext;

 // Private api to be called from CUDAHooks.cpp
-C10_CUDA_API void setHasPrimaryContext(bool (*func)(int64_t)) {
+C10_CUDA_API void setHasPrimaryContext(bool (*func)(DeviceIndex)) {
   hasPrimaryContext = func ? func : dummyHasPrimaryContext;
 }
 } // namespace _internal

-bool hasPrimaryContext(int64_t device_index) {
+bool hasPrimaryContext(DeviceIndex device_index) {
   return _internal::hasPrimaryContext(device_index);
 }
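The hunk above is the whole injection mechanism: c10 cannot link against ATen, so CUDAHooks.cpp installs the real `hasPrimaryContext` through `setHasPrimaryContext` at load time, replacing a dummy that fails loudly if called first. A self-contained sketch of the pattern, with hypothetical `lowlevel` names in place of the c10/ATen split:

```cpp
// Sketch of the function-pointer injection pattern used above
// (hypothetical "lowlevel" names; error handling simplified).
#include <cstdint>
#include <iostream>
#include <stdexcept>

using DeviceIndex = int8_t;

namespace lowlevel {
namespace _internal {
inline bool dummyHasPrimaryContext(DeviceIndex) {
  throw std::logic_error("hasPrimaryContext called before registration");
}
inline bool (*hasPrimaryContext)(DeviceIndex) = dummyHasPrimaryContext;
inline void setHasPrimaryContext(bool (*func)(DeviceIndex)) {
  hasPrimaryContext = func ? func : dummyHasPrimaryContext;
}
}  // namespace _internal

// Public entry point: forwards to whatever implementation was registered.
inline bool hasPrimaryContext(DeviceIndex device_index) {
  return _internal::hasPrimaryContext(device_index);
}
}  // namespace lowlevel

// The higher-level library registers its real implementation at startup:
bool realHasPrimaryContext(DeviceIndex) { return true; }
static const bool registered =
    (lowlevel::_internal::setHasPrimaryContext(realHasPrimaryContext), true);

int main() {
  std::cout << lowlevel::hasPrimaryContext(0) << "\n";  // prints 1
}
```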
@@ -111,8 +111,8 @@ C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) {
   C10_CUDA_CHECK(cudaStreamSynchronize(stream));
 }

-C10_CUDA_API bool hasPrimaryContext(int64_t device_index);
-C10_CUDA_API c10::optional<int64_t> getDeviceIndexWithPrimaryContext();
+C10_CUDA_API bool hasPrimaryContext(DeviceIndex device_index);
+C10_CUDA_API c10::optional<DeviceIndex> getDeviceIndexWithPrimaryContext();

 } // namespace cuda
 } // namespace c10
@@ -58,6 +58,7 @@ TypePtr SchemaTypeParser::parseBaseType() {
       // use the custom class mechanism
       // instead. @jerryzh
       {"Device", c10::TypeFactory::get<DeviceObjType>()},
+      {"DeviceIndex", c10::TypeFactory::get<IntType>()},
       {"Stream", c10::TypeFactory::get<StreamObjType>()},
       {"Scalar", c10::TypeFactory::get<NumberType>()},
       {"str", c10::TypeFactory::get<StringType>()},
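With this entry, `DeviceIndex` in a schema string is accepted by the parser and represented as the JIT's plain `int` type. A hedged sketch of checking that, assuming `torch::jit::parseSchema` from torch/csrc/jit/frontend/function_schema_parser.h and a build linked against libtorch:

```cpp
// Sketch (assumes libtorch): parse a schema using DeviceIndex and inspect
// the resulting argument type.
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <iostream>

int main() {
  auto schema = torch::jit::parseSchema(
      "_cufft_clear_plan_cache(DeviceIndex device_index) -> ()");
  // DeviceIndex is mapped to IntType, so this should print "int".
  std::cout << *schema.arguments()[0].type() << "\n";
}
```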
@@ -43,6 +43,7 @@ static std::unordered_map<std::string, ParameterType> type_map = {
     {"MemoryFormat", ParameterType::MEMORY_FORMAT},
     {"QScheme", ParameterType::QSCHEME},
     {"Device", ParameterType::DEVICE},
+    {"DeviceIndex", ParameterType::INT64},
     {"Stream", ParameterType::STREAM},
     {"std::string", ParameterType::STRING},
     {"c10::string_view", ParameterType::STRING},
@@ -661,6 +661,7 @@ def argument_type_str(
             BaseTy.Storage,
             BaseTy.Layout,
             BaseTy.Device,
+            BaseTy.DeviceIndex,
             BaseTy.MemoryFormat,
             BaseTy.Dimname,
             BaseTy.Stream,

@@ -907,7 +908,7 @@ def argument_type_str_pyi(t: Type) -> str:
         add_optional = True

     if isinstance(t, BaseType):
-        if t.name == BaseTy.int:
+        if t.name in [BaseTy.int, BaseTy.DeviceIndex]:
             ret = "_int"
         if t.name == BaseTy.SymInt:
             ret = "Union[_int, SymInt]"

@@ -1255,6 +1256,8 @@ def arg_parser_unpack_method(
         return "scalartypeWithDefault" if has_default_init else "scalartype"
     elif t.name == BaseTy.Device:
         return "deviceWithDefault" if has_default_init else "device"
+    elif t.name == BaseTy.DeviceIndex:
+        return "toInt64"
     elif t.name == BaseTy.int:
         return "toInt64"
     elif t.name == BaseTy.SymInt:
@@ -62,6 +62,7 @@ dimnameListT = BaseCppType("at", "DimnameList")
 dimVectorT = BaseCppType("at", "DimVector")
 layoutT = BaseCppType("at", "Layout")
 deviceT = BaseCppType("at", "Device")
+deviceIndexT = BaseCppType("at", "DeviceIndex")
 scalarT = BaseCppType("at", "Scalar")
 optionalScalarRefT = BaseCppType("at", "OptionalScalarRef")
 memoryFormatT = BaseCppType("at", "MemoryFormat")

@@ -111,6 +112,7 @@ BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {
     BaseTy.DimVector: dimVectorT,
     BaseTy.Layout: layoutT,
     BaseTy.Device: deviceT,
+    BaseTy.DeviceIndex: deviceIndexT,
     BaseTy.Scalar: scalarT,
     BaseTy.MemoryFormat: memoryFormatT,
     BaseTy.QScheme: qschemeT,
@@ -1803,6 +1803,7 @@ class BaseTy(Enum):
     bool = auto()
     Layout = auto()
     Device = auto()
+    DeviceIndex = auto()
     Scalar = auto()
     MemoryFormat = auto()
     QScheme = auto()