Revert "Increased compile time max GPUs to 512. Switched to int16_t DeviceIndex. (#119639)"

This reverts commit 7c556428c74a79c6d9c272826344a0828d3f66f5.

Reverted https://github.com/pytorch/pytorch/pull/119639 on behalf of https://github.com/kit1980 due to breaking internal builds, see D54286923 ([comment](https://github.com/pytorch/pytorch/pull/119639#issuecomment-1969634480))
PyTorch MergeBot
2024-02-28 18:57:08 +00:00
parent 1c67f6cb26
commit a9d9077f12
23 changed files with 99 additions and 140 deletions

View File

@@ -37,8 +37,8 @@ constexpr int checkStaticTypes() {
   // Give nice error messages for some of the common error cases.
   // Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT
   static_assert(std::conjunction<
-    bool_t<!std::is_integral<Types>::value || std::is_same<Types, int16_t>::value || std::is_same<Types, int64_t>::value || std::is_same<Types, bool>::value>...
-  >::value, "INVALID TYPE: Only int16_t, int64_t and bool are supported as an integral argument type");
+    bool_t<!std::is_integral<Types>::value || std::is_same<Types, int8_t>::value || std::is_same<Types, int64_t>::value || std::is_same<Types, bool>::value>...
+  >::value, "INVALID TYPE: Only int8_t, int64_t and bool are supported as an integral argument type");
   static_assert(std::conjunction<
     bool_t<!std::is_same<Types, float>::value>...
   >::value, "INVALID TYPE: float is not supported as an argument type, use double instead");

View File

@@ -43,7 +43,8 @@ static inline void launch_jitted_vectorized_kernel_dynamic(
   ss << static_cast<int>(at::cuda::jit::BinaryFuncVariant::NoScalar);
   ss << extra_args_types;
   ss << vec_size;
-  ss << dev_idx;
+  // DeviceIndex, e.g. int8_t, is not treated as a number by the stream, cast to int as a workaround
+  ss << static_cast<int>(dev_idx);
   const std::string cache_key = ss.str();
   static std::mutex _jiterator_mutex;

View File

@@ -252,17 +252,10 @@ using IndicesT = std::vector<size_t>;
 using nested_optional_tensorvec_t =
     std::vector<std::vector<c10::optional<at::Tensor>>>;
 using TensorsAndIndicesT = std::pair<nested_optional_tensorvec_t, IndicesT>;
-// Warning: Do not use ParamsHash for keys with potentially uninitialized
-// padding bytes!
-struct _DeviceDtypeHasher {
-  std::size_t operator()(const DeviceDtypeKey& k) const noexcept {
-    return std::hash<at::Device>{}(k.first) ^
-        std::hash<at::ScalarType>{}(k.second);
-  }
-};
-using FlatMap =
-    std::unordered_map<DeviceDtypeKey, TensorsAndIndicesT, _DeviceDtypeHasher>;
+using FlatMap = std::unordered_map<
+    DeviceDtypeKey,
+    TensorsAndIndicesT,
+    ParamsHash<DeviceDtypeKey>>;
 
 inline FlatMap _group_tensors_by_first_tensors_device_and_dtype(
     const nested_optional_tensorvec_t& nested_tensorlist,

View File

@@ -10,8 +10,6 @@ namespace at::native {
 // FowlerNollVo hash function
 // see
 // https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
-// WARNING: This hash function will produce unexpected results for `Params` with uninitialized padding values, as the
-// padding is also part of the hash. Use with caution.
 template <typename Params>
 struct ParamsHash {
   // Params must be a POD because we read out its memory

View File

@@ -125,29 +125,19 @@ Device::Device(const std::string& device_string) : Device(Type::CPU) {
   TORCH_CHECK(!has_error, "Invalid device string: '", device_string, "'");
-  if (!device_index_str.empty()) {
-    // If the user passed an index in the device string, check if it is a valid
-    // int between 0 and c10::Device::MAX_NUM_DEVICES - 1 inclusively
-    int full_index = -1;
-    try {
-      full_index = std::stoi(device_index_str);
-    } catch (const std::exception&) {
-      TORCH_CHECK(
-          false,
-          "Could not parse device index '",
-          device_index_str,
-          "' in device string '",
-          device_string,
-          "'");
+  try {
+    if (!device_index_str.empty()) {
+      index_ = static_cast<c10::DeviceIndex>(std::stoi(device_index_str));
     }
+  } catch (const std::exception&) {
     TORCH_CHECK(
-        0 <= full_index && full_index < c10::Device::MAX_NUM_DEVICES,
-        "Device index must be between 0 and ",
-        c10::Device::MAX_NUM_DEVICES - 1,
-        " inclusively.");
-    index_ = static_cast<c10::DeviceIndex>(full_index);
+        false,
+        "Could not parse device index '",
+        device_index_str,
+        "' in device string '",
+        device_string,
+        "'");
   }
   type_ = parse_type(device_name);
   validate();
 }
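
The restored code simply wraps std::stoi in a try/catch and leaves range checking to validate(). A standalone sketch of that strategy (hypothetical helper, not the c10 implementation):

#include <cstdint>
#include <stdexcept>
#include <string>

int8_t parse_device_index(const std::string& device_index_str,
                          const std::string& device_string) {
  try {
    // std::stoi throws std::invalid_argument or std::out_of_range on bad input.
    // Note: values outside the int8_t range silently wrap here, which is what
    // the reverted bound check had guarded against.
    return static_cast<int8_t>(std::stoi(device_index_str));
  } catch (const std::exception&) {
    throw std::runtime_error("Could not parse device index '" +
                             device_index_str + "' in device string '" +
                             device_string + "'");
  }
}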

View File

@@ -16,7 +16,7 @@ namespace c10 {
 /// A DeviceIndex is not independently meaningful without knowing
 /// the DeviceType it is associated; try to use Device rather than
 /// DeviceIndex directly.
-using DeviceIndex = int16_t;
+using DeviceIndex = int8_t;
 
 /// Represents a compute device on which a tensor is located. A device is
 /// uniquely identified by a type, which specifies the type of machine it is
@@ -29,18 +29,6 @@ using DeviceIndex = int16_t;
 /// represents a specific, concrete device,
 /// 2. When the device type is CPU, the device index must be zero.
 struct C10_API Device final {
-  /// The maximum number of devices that we recognize (formerly known as
-  /// C10_COMPILE_TIME_MAX_GPUS). This value cannot be more than 32767 because
-  /// our DeviceIndex is a int16_t. Note that this does not include the default
-  /// device index -1, but instead defines the range from 0 to MAX_NUM_DEVICES-1
-  /// inclusively.
-#ifdef FBCODE_CAFFE2
-  // fbcode depends on this value being 16
-  static constexpr DeviceIndex MAX_NUM_DEVICES = 16;
-#else
-  static constexpr DeviceIndex MAX_NUM_DEVICES = 512;
-#endif
-
   using Type = DeviceType;
 
   /// Constructs a new `Device` from a `DeviceType` and an optional device
@@ -72,7 +60,6 @@ struct C10_API Device final {
   /// Sets the device index.
   void set_index(DeviceIndex index) {
     index_ = index;
-    validate();
   }
 
   /// Returns the type of device this is.
@@ -188,10 +175,8 @@ struct C10_API Device final {
   // This is safe to do, because backends that use the DeviceIndex
   // have a later check when we actually try to switch to that device.
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-      index_ >= -1 && index_ < MAX_NUM_DEVICES,
-      "Device index must be between -1 and ",
-      MAX_NUM_DEVICES - 1,
-      " inclusively, got ",
+      index_ >= -1,
+      "Device index must be -1 or non-negative, got ",
       static_cast<int>(index_));
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
       !is_cpu() || index_ <= 0,
@@ -211,7 +196,7 @@ struct hash<c10::Device> {
   // Are you here because this static assert failed? Make sure you ensure
   // that the bitmasking code below is updated accordingly!
   static_assert(sizeof(c10::DeviceType) == 1, "DeviceType is not 8-bit");
-  static_assert(sizeof(c10::DeviceIndex) == 2, "DeviceIndex is not 16-bit");
+  static_assert(sizeof(c10::DeviceIndex) == 1, "DeviceIndex is not 8-bit");
   // Note [Hazard when concatenating signed integers]
   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
   // We must first convert to a same-sized unsigned type, before promoting to
@@ -224,7 +209,7 @@ struct hash<c10::Device> {
   // sake.
   uint32_t bits = static_cast<uint32_t>(static_cast<uint8_t>(d.type()))
           << 16 |
-      static_cast<uint32_t>(static_cast<uint16_t>(d.index()));
+      static_cast<uint32_t>(static_cast<uint8_t>(d.index()));
   return std::hash<uint32_t>{}(bits);
  }
 };
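
The nested casts exist because of the sign-extension hazard the Note above describes. A minimal sketch (illustrative only) of what goes wrong without the same-sized unsigned conversion:

#include <cstdint>
#include <cstdio>

int main() {
  int8_t index = -1;  // the default "no index" value
  uint8_t type = 1;   // stand-in for a device type tag

  // Promoting the signed index directly sign-extends it to 0xFFFFFFFF and
  // clobbers the bits holding the device type.
  uint32_t bad = (static_cast<uint32_t>(type) << 16) |
      static_cast<uint32_t>(index);
  // Converting to the same-sized unsigned type first keeps only 8 bits.
  uint32_t good = (static_cast<uint32_t>(type) << 16) |
      static_cast<uint32_t>(static_cast<uint8_t>(index));

  std::printf("%08x vs %08x\n", bad, good);  // ffffffff vs 000100ff
  return 0;
}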

View File

@@ -3169,7 +3169,7 @@ class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl {
 #if UINTPTR_MAX == 0xFFFFFFFF
   // This is a 32-bit system
   static constexpr bool check_sizes() {
-    constexpr size_t tsize = 21 * sizeof(int64_t);
+    constexpr size_t tsize = 20 * sizeof(int64_t);
 
     // clang-format off
     are_equal<sizeof(storage_), 4, FieldNameEnum::storage_>();
@@ -3181,7 +3181,7 @@ class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl {
     are_equal<sizeof(storage_offset_), 8, FieldNameEnum::storage_offset_>();
     are_equal<sizeof(numel_), 8, FieldNameEnum::numel_>();
     are_equal<sizeof(data_type_), 2, FieldNameEnum::data_type_>();
-    are_equal<sizeof(device_opt_), 6, FieldNameEnum::device_opt_>();
+    are_equal<sizeof(device_opt_), 3, FieldNameEnum::device_opt_>();
     are_equal<sizeof(key_set_), 8, FieldNameEnum::key_set_>();
     is_le<sizeof(TensorImpl), tsize, FieldNameEnum::TOTAL_SIZE>();
     // clang-format on
@@ -3206,7 +3206,7 @@ class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl {
     are_equal<sizeof(storage_offset_), 8, FieldNameEnum::storage_offset_>();
     are_equal<sizeof(numel_), 8, FieldNameEnum::numel_>();
     are_equal<sizeof(data_type_), 2, FieldNameEnum::data_type_>();
-    are_equal<sizeof(device_opt_), 6, FieldNameEnum::device_opt_>();
+    are_equal<sizeof(device_opt_), 3, FieldNameEnum::device_opt_>();
     are_equal<sizeof(key_set_), 8, FieldNameEnum::key_set_>();
     is_le<sizeof(TensorImpl), tsize, FieldNameEnum::TOTAL_SIZE>();
     // clang-format on

View File

@@ -99,7 +99,7 @@ DeviceIndex device_count() noexcept {
   try {
     auto result = device_count_impl(/*fail_if_no_driver=*/false);
     TORCH_INTERNAL_ASSERT(
-        result <= c10::Device::MAX_NUM_DEVICES,
+        result <= std::numeric_limits<DeviceIndex>::max(),
         "Too many CUDA devices, DeviceIndex overflowed");
     return result;
   } catch (const c10::Error& ex) {
@@ -118,7 +118,7 @@ DeviceIndex device_count_ensure_non_zero() {
   // Zero gpus doesn't produce a warning in `device_count` but we fail here
   TORCH_CHECK(count, "No CUDA GPUs are available");
   TORCH_INTERNAL_ASSERT(
-      count <= c10::Device::MAX_NUM_DEVICES,
+      count <= std::numeric_limits<DeviceIndex>::max(),
       "Too many CUDA devices, DeviceIndex overflowed");
   return static_cast<DeviceIndex>(count);
 }
@@ -219,7 +219,8 @@ cudaError_t GetDevice(DeviceIndex* device) {
   auto err = cudaGetDevice(&tmp_device);
   if (err == cudaSuccess) {
     TORCH_INTERNAL_ASSERT(
-        tmp_device >= 0 && tmp_device < c10::Device::MAX_NUM_DEVICES,
+        tmp_device >= 0 &&
+            tmp_device <= std::numeric_limits<DeviceIndex>::max(),
         "cudaGetDevice returns invalid device ",
         tmp_device);
     *device = static_cast<DeviceIndex>(tmp_device);
@@ -269,7 +270,8 @@ DeviceIndex MaybeExchangeDevice(DeviceIndex to_device) {
   int tmp_cur_device = -1;
   C10_CUDA_CHECK(cudaGetDevice(&tmp_cur_device));
   TORCH_INTERNAL_ASSERT(
-      tmp_cur_device >= 0 && tmp_cur_device < c10::Device::MAX_NUM_DEVICES,
+      tmp_cur_device >= 0 &&
+          tmp_cur_device <= std::numeric_limits<DeviceIndex>::max(),
       "cudaGetDevice returns invalid device ",
       tmp_cur_device);
   auto cur_device = static_cast<DeviceIndex>(tmp_cur_device);
@@ -295,7 +297,8 @@ cudaError_t GetDevice(DeviceIndex* device) {
   auto err = cudaGetDevice(&tmp_device);
   if (err == cudaSuccess) {
     TORCH_INTERNAL_ASSERT(
-        tmp_device >= 0 && tmp_device < c10::Device::MAX_NUM_DEVICES,
+        tmp_device >= 0 &&
+            tmp_device <= std::numeric_limits<DeviceIndex>::max(),
         "cudaGetDevice returns invalid device ",
         tmp_device);
     *device = static_cast<DeviceIndex>(tmp_device);

View File

@@ -37,3 +37,15 @@
 #else
 #define C10_CUDA_API C10_CUDA_IMPORT
 #endif
+
+/**
+ * The maximum number of GPUs that we recognizes. Increasing this beyond the
+ * initial limit of 16 broke Caffe2 testing, hence the ifdef guards.
+ * This value cannot be more than 255 because our DeviceIndex is a uint8_t.
+ */
+#ifdef FBCODE_CAFFE2
+// fbcode depends on this value being 16
+#define C10_COMPILE_TIME_MAX_GPUS 16
+#else
+#define C10_COMPILE_TIME_MAX_GPUS 64
+#endif

View File

@@ -38,18 +38,18 @@ static int max_stream_priorities;
 // the destruction.
 #if !defined(USE_ROCM)
 // CUDA-only: used to initializes the stream pools (once)
-static c10::once_flag device_flags[c10::Device::MAX_NUM_DEVICES];
+static c10::once_flag device_flags[C10_COMPILE_TIME_MAX_GPUS];
 #endif
 
 static std::atomic<uint32_t>
     priority_counters[c10::cuda::max_compile_time_stream_priorities]
-                     [c10::Device::MAX_NUM_DEVICES];
+                     [C10_COMPILE_TIME_MAX_GPUS];
 
 static cudaStream_t streams[c10::cuda::max_compile_time_stream_priorities]
-                           [c10::Device::MAX_NUM_DEVICES][kStreamsPerPool];
+                           [C10_COMPILE_TIME_MAX_GPUS][kStreamsPerPool];
 #ifdef USE_ROCM
 static c10::once_flag
     stream_flags[c10::cuda::max_compile_time_stream_priorities]
-                [c10::Device::MAX_NUM_DEVICES][kStreamsPerPool];
+                [C10_COMPILE_TIME_MAX_GPUS][kStreamsPerPool];
 #endif
 
 // Note [HIP Lazy Streams]
@@ -168,10 +168,10 @@ static void initGlobalStreamState() {
   // Check if the number of GPUs matches the expected compile-time max number
   // of GPUs.
   TORCH_CHECK(
-      num_gpus <= c10::Device::MAX_NUM_DEVICES,
+      num_gpus <= C10_COMPILE_TIME_MAX_GPUS,
       "Number of CUDA devices on the machine is larger than the compiled "
       "max number of gpus expected (",
-      c10::Device::MAX_NUM_DEVICES,
+      C10_COMPILE_TIME_MAX_GPUS,
       "). Increase that and recompile.");
   int leastPriority = -1, greatestPriority = -1;
   C10_CUDA_CHECK(

View File

@@ -224,8 +224,8 @@ std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>> ncclOpDevInfer(
 REGISTER_CUDA_OPERATOR(NCCLAllreduce, NCCLAllreduceOp);
 OPERATOR_SCHEMA(NCCLAllreduce)
-    .NumInputs(1, c10::Device::MAX_NUM_DEVICES)
-    .NumOutputs(1, c10::Device::MAX_NUM_DEVICES)
+    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
+    .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
     .CostInferenceFunction(NCCLAllreduceOp::CostInference)
     .TensorInferenceFunction(NCCLAllreduceOp::ShapeInference)
     .IdenticalTypeAndShape()
@@ -236,8 +236,8 @@ SHOULD_NOT_DO_GRADIENT(NCCLAllreduce);
 REGISTER_CUDA_OPERATOR(NCCLBroadcast, NCCLBroadcastOp);
 OPERATOR_SCHEMA(NCCLBroadcast)
-    .NumInputs(1, c10::Device::MAX_NUM_DEVICES)
-    .NumOutputs(1, c10::Device::MAX_NUM_DEVICES)
+    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
+    .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
     .IdenticalTypeAndShape()
     .InputsCanCrossDevices()
     .EnforceOneToOneInplace()
@@ -247,7 +247,7 @@ SHOULD_NOT_DO_GRADIENT(NCCLBroadcast);
 REGISTER_CUDA_OPERATOR(NCCLReduce, NCCLReduceOp);
 OPERATOR_SCHEMA(NCCLReduce)
-    .NumInputs(1, c10::Device::MAX_NUM_DEVICES)
+    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
     .NumOutputs(1)
     .IdenticalTypeAndShapeOfInput(0)
     .InputsCanCrossDevices()
@@ -257,16 +257,16 @@ SHOULD_NOT_DO_GRADIENT(NCCLReduce);
 REGISTER_CUDA_OPERATOR(NCCLAllGather, NCCLAllGatherOp);
 OPERATOR_SCHEMA(NCCLAllGather)
-    .NumInputs(1, c10::Device::MAX_NUM_DEVICES)
-    .NumOutputs(1, c10::Device::MAX_NUM_DEVICES)
+    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
+    .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
     .InputsCanCrossDevices()
     .DeviceInferenceFunction(ncclOpDevInfer);
 SHOULD_NOT_DO_GRADIENT(NCCLAllGather);
 
 REGISTER_CUDA_OPERATOR(NCCLReduceScatter, NCCLReduceScatterOp);
 OPERATOR_SCHEMA(NCCLReduceScatter)
-    .NumInputs(1, c10::Device::MAX_NUM_DEVICES)
-    .NumOutputs(1, c10::Device::MAX_NUM_DEVICES)
+    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
+    .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
     .InputsCanCrossDevices()
     .DeviceInferenceFunction(ncclOpDevInfer);
 SHOULD_NOT_DO_GRADIENT(NCCLReduceScatter);

View File

@@ -178,8 +178,8 @@ static std::unordered_map<void*, uint8_t> g_cuda_device_affiliation;
 // Data structures for optional memory tracking. Access to these structures
 // is guarded by the CUDAContext::mutex.
 static std::unordered_map<void*, long> g_size_map;
-static std::vector<long> g_total_by_gpu_map(c10::Device::MAX_NUM_DEVICES, 0);
-static std::vector<long> g_max_by_gpu_map(c10::Device::MAX_NUM_DEVICES, 0);
+static std::vector<long> g_total_by_gpu_map(C10_COMPILE_TIME_MAX_GPUS, 0);
+static std::vector<long> g_max_by_gpu_map(C10_COMPILE_TIME_MAX_GPUS, 0);
 
 static long g_total_mem = 0;
 static long g_last_rep = 0;
@@ -208,10 +208,10 @@ static void Caffe2InitializeCuda() {
   // of GPUs.
   CAFFE_ENFORCE_LE(
       NumCudaDevices(),
-      c10::Device::MAX_NUM_DEVICES,
+      C10_COMPILE_TIME_MAX_GPUS,
       "Number of CUDA devices on the machine is larger than the compiled "
       "max number of gpus expected (",
-      c10::Device::MAX_NUM_DEVICES,
+      C10_COMPILE_TIME_MAX_GPUS,
       "). Increase that and recompile.");
 
   for (DeviceIndex i = 0; i < NumCudaDevices(); ++i) {

View File

@@ -58,7 +58,7 @@ class CAFFE2_CUDA_API ThreadLocalCUDAObjects {
  private:
   ThreadLocalCUDAObjects() {
-    for (DeviceIndex i = 0; i < c10::Device::MAX_NUM_DEVICES; ++i) {
+    for (DeviceIndex i = 0; i < C10_COMPILE_TIME_MAX_GPUS; ++i) {
       cuda_streams_[i] = vector<c10::cuda::CUDAStream>();
     }
   }
@@ -164,7 +164,7 @@ class CAFFE2_CUDA_API ThreadLocalCUDAObjects {
   // WARNING: mapping from logical stream ID to c10::cuda::CUDAStream
   // is NOT bijective; multiple logical stream IDs may map to the
   // same underlying stream ID.
-  vector<c10::cuda::CUDAStream> cuda_streams_[c10::Device::MAX_NUM_DEVICES];
+  vector<c10::cuda::CUDAStream> cuda_streams_[C10_COMPILE_TIME_MAX_GPUS];
   std::unordered_map<c10::cuda::CUDAStream, cublasHandle_t> cublas_handles_;
 #ifdef CAFFE2_USE_CUDNN
   std::unordered_map<c10::cuda::CUDAStream, cudnnHandle_t> cudnn_handles_;

View File

@@ -188,7 +188,7 @@ class CuDNNWrapper {
   using PerGPUCuDNNStates = std::array<
       std::array<SyncedCuDNNState, CAFFE2_COMPILE_TIME_MAX_CUDNN_STATES>,
-      c10::Device::MAX_NUM_DEVICES>;
+      C10_COMPILE_TIME_MAX_GPUS>;
   static PerGPUCuDNNStates& cudnn_states();
 
   C10_DISABLE_COPY_AND_ASSIGN(CuDNNWrapper);

View File

@@ -155,7 +155,7 @@ class MIOPENWrapper
   using PerGPUMIOPENStates = std::array<
       std::array<SyncedMIOPENState, CAFFE2_COMPILE_TIME_MAX_MIOPEN_STATES>,
-      c10::Device::MAX_NUM_DEVICES>;
+      C10_COMPILE_TIME_MAX_GPUS>;
   static PerGPUMIOPENStates& miopen_states();
 
   C10_DISABLE_COPY_AND_ASSIGN(MIOPENWrapper);

View File

@@ -3440,21 +3440,6 @@ def foo(x):
         else:
             cu.define(full)
 
-    def test_int16_device_index(self):
-        # This used to fail after the switch from int8 to int16 DeviceIndex as the ArgumentInfo struct hardcoded
-        # the bit width. Thus, the default device (-1) wrapped around to 255.
-        # See https://github.com/pytorch/pytorch/issues/115331
-        tensor = torch.tensor([1.])
-        code_template = """
-        def fn(x):
-            return x.device
-        """
-        cu = torch.jit.CompilationUnit()
-        cu.define(code_template)
-        res = cu.fn(tensor)
-        self.assertEqual(tensor.device, res)
-
     def test_namedtuple_python(self):
         global MyTuple, MyMod  # see [local resolution in python]
         MyTuple = namedtuple('MyTuple', ['a'])

View File

@@ -1017,22 +1017,6 @@ class TestDeviceUtils(TestCase):
             tree_all_only(torch.Tensor, lambda x: x.device.type == 'meta', r)
         )
 
-    def test_int16_device_index(self):
-        # Test if index does not wrap around when larger than int8
-        large_index = 500
-        x = torch.device('meta', large_index)
-        self.assertEqual(x.index, large_index)
-
-    def test_raise_on_device_index_out_of_bounds(self):
-        # Tests if an error is raised when the device index is out of bounds
-        index_larger_than_max = 100000
-        error_msg_regex = "^Device index must be.*"
-        # Explicit index
-        with self.assertRaisesRegex(RuntimeError, error_msg_regex):
-            x = torch.device('meta', index=index_larger_than_max)
-        # Index in device string
-        with self.assertRaisesRegex(RuntimeError, error_msg_regex):
-            x = torch.device(f'meta:{index_larger_than_max}')
-
 
 instantiate_device_type_tests(TestDeviceUtils, globals())

View File

@@ -31,7 +31,10 @@ PyObject* THPDevice_repr(THPDevice* self) {
   std::ostringstream oss;
   oss << "device(type=\'" << self->device.type() << "\'";
   if (self->device.has_index()) {
-    oss << ", index=" << self->device.index();
+    // `self->device.index()` returns uint8_t which is treated as ascii while
+    // printing, hence casting it to uint16_t.
+    // https://stackoverflow.com/questions/19562103/uint8-t-cant-be-printed-with-cout
+    oss << ", index=" << static_cast<uint16_t>(self->device.index());
   }
   oss << ")";
   return THPUtils_packString(oss.str().c_str());
@@ -74,11 +77,7 @@ PyObject* THPDevice_pynew(
     device_index = r.toInt64(1);
     // -1 is allowed in ATen/C++, to mean the default device, but not in
     // Python.
-    TORCH_CHECK(
-        device_index >= 0 && device_index < c10::Device::MAX_NUM_DEVICES,
-        "Device index must be between 0 and ",
-        c10::Device::MAX_NUM_DEVICES - 1,
-        " inclusively.");
+    TORCH_CHECK(device_index >= 0, "Device index must not be negative");
   }
   at::Device device(
       as_device.type(), static_cast<c10::DeviceIndex>(device_index));

View File

@@ -2028,10 +2028,23 @@ Call this whenever a new thread is created in order to propagate values from
 // torch/csrc/pybind.h` would solve this but it caused segmentation fault in
 // my environment.
 using _DeviceDtypeKey = std::pair<at::Device, std::string>;
+// Custom hasher is necessary to make unordered_map compilable for Windows
+// debug targets. As `at::native::ParamsHash` only works on structs with
+// standard layout, but std::string isn't one in Visual C++ debug builds,
+// which one can easily verify by running something like:
+//   #define _DEBUG
+//   #include <type_traits>
+//   #include <string>
+//   static_assert(std::is_standard_layout_v<std::string>, "Oh noes");
+// If above condition is not met, VC++ raises a very cryptic compilation
+// error. See
+// https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for
+// more detail
 struct _DeviceDtypeHasher {
   std::size_t operator()(const _DeviceDtypeKey& k) const noexcept {
-    return std::hash<at::Device>{}(k.first) ^
-        std::hash<std::string>{}(k.second);
+    static at::native::ParamsHash<at::Device> device_hasher;
+    static std::hash<std::string> string_hasher;
+    return device_hasher(k.first) ^ string_hasher(k.second);
   }
 };
 using _FlatMap = std::unordered_map<
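
The standard-layout requirement the restored comment mentions can be checked directly with a static_assert. A small sketch of that check (assumes C++17; not PyTorch code):

#include <string>
#include <type_traits>

// A plain aggregate like this is standard-layout, so a byte-wise hash in the
// style of at::native::ParamsHash is safe on it.
struct PodKey {
  int device;
  int dtype;
};
static_assert(std::is_standard_layout_v<PodKey>, "byte-hashable");

// std::string is typically standard-layout as well, but under MSVC debug
// builds (_DEBUG) it is not, so std::pair<at::Device, std::string> cannot be
// byte-hashed there; hashing the two components separately sidesteps this.
// static_assert(std::is_standard_layout_v<std::string>, "fails on MSVC debug");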

View File

@@ -1,12 +1,11 @@
 #pragma once
 
-#include <c10/core/Device.h>
 #include <c10/cuda/CUDAMacros.h>
 
 #include <bitset>
 #include <cstddef>
 
 namespace torch {
-using device_set = std::bitset<c10::Device::MAX_NUM_DEVICES>;
+using device_set = std::bitset<C10_COMPILE_TIME_MAX_GPUS>;
 
 } // namespace torch

View File

@@ -2,7 +2,6 @@
 #include <ATen/core/jit_type.h>
 #include <ATen/core/stack.h>
-#include <c10/core/Device.h>
 #include <c10/util/hash.h>
 #include <c10/util/irange.h>
 #include <torch/csrc/Export.h>
@@ -57,10 +56,12 @@ struct ArgumentInfo {
  private:
   unsigned defined_ : 1;
   unsigned requires_grad_ : 1;
+  unsigned : 5;
   unsigned dim_ : 8;
-  signed device_ : sizeof(c10::DeviceIndex) * 8;
+  unsigned device_ : 8;
   unsigned type_ : 8;
   unsigned dev_type_ : 16;
+  unsigned : 16;
 };
 
 static_assert(
@@ -68,7 +69,7 @@ static_assert(
     "ArgumentInfo is to be a POD struct");
 static_assert(
     sizeof(ArgumentInfo) == sizeof(ArgumentInfo::plain_data_type),
-    "ArgumentInfo is expected to be a 64-bit struct");
+    "ArgumentInfo is expected to be a 32-bit struct");
 
 struct ArgumentSpec {
   ArgumentSpec(size_t num_flat_tensor_inputs, size_t num_flat_optional_inputs)
@@ -222,8 +223,8 @@ struct CompleteArgumentInfoPOD {
   unsigned type : 8; // scalar type
   unsigned defined : 1;
   unsigned requires_grad : 1;
-  signed dev_type : sizeof(c10::DeviceType) * 8;
-  signed device : sizeof(c10::DeviceIndex) * 8;
+  signed device : 14;
+  unsigned dev_type : 16;
   unsigned
       total_dims : 16; // all TensorInfoPODs are in CompleteArgumentSpec's
   // tensor_info() array. total_dims is the total number of
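
The restored ArgumentInfo layout leans on bit-fields plus unnamed padding fields so the struct size is predictable enough to static_assert on. A minimal sketch of the pattern (not the real struct):

#include <cstdint>

struct PackedInfo {
  unsigned defined : 1;
  unsigned requires_grad : 1;
  unsigned : 6;        // unnamed field: explicit padding to the next byte
  unsigned dim : 8;
  unsigned device : 8;
  unsigned type : 8;   // 32 bits total
};

// On mainstream compilers these 32 bits share one allocation unit, so the
// struct can be copied and hashed as a single 32-bit word.
static_assert(sizeof(PackedInfo) == sizeof(uint32_t),
              "PackedInfo is expected to be a 32-bit struct");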

View File

@@ -807,11 +807,7 @@ inline at::Device toDevice(PyObject* obj) {
   }
   if (THPUtils_checkLong(obj)) {
     const auto device_index = THPUtils_unpackLong(obj);
-    TORCH_CHECK(
-        device_index >= 0 && device_index < c10::Device::MAX_NUM_DEVICES,
-        "Device index must be between 0 and ",
-        c10::Device::MAX_NUM_DEVICES - 1,
-        " inclusively.");
+    TORCH_CHECK(device_index >= 0, "Device index must not be negative");
     if (c10::is_privateuse1_backend_registered()) {
       return at::Device(
           c10::DeviceType::PrivateUse1,

View File

@@ -5258,12 +5258,12 @@ class TensorPipeAgentCudaRpcTest(RpcAgentTestFixture, RpcTestCommon):
         options = self.rpc_backend_options
         dst = worker_name((self.rank + 1) % self.world_size)
 
         with self.assertRaisesRegex(
-            RuntimeError, "Device index must .*"
+            RuntimeError, "Device index must not be negative"
         ):
             options.set_device_map(dst, {-1: 0})
 
         with self.assertRaisesRegex(
-            RuntimeError, "Device index must .*"
+            RuntimeError, "Device index must not be negative"
         ):
             options.set_device_map(dst, {0: -1})