mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Increased compile time max GPUs to 512. Switched to int16_t DeviceIndex. (#119639)
Fixes #115331. This PR increases the number of valid GPU devices to 512 (from 64) in order to future-proof PyTorch for providers that offer [single nodes with a large device count](https://www.tensorwave.com/). Until now, `DeviceIndex` was an `int8_t`, thus multiple changes were necessary: - `DeviceIndex` changed to `int16_t`. Updated consumers that assume it to be an `int8_t`. - Updated bounds checking for `torch.device()` in the Python frontend. Right now, we allow funny things like `torch.device('cpu', 200).index == -56`, which is undefined behavior. I inserted some checks to only allow values between 0 and `c10::Device::MAX_NUM_DEVICES - 1`. - Updated the `ArgumentInfo` struct as it hardcodes the device index as an 8-bit field [^1]. Might be a breaking change, not sure if users rely on this. - Introduced `c10::Device::MAX_NUM_DEVICES` as a replacement for the old `C10_COMPILE_TIME_MAX_GPUS` [^1]: This field was unsigned, so I guess this has also been undefined behavior the whole time? Our default device index is -1, so this always wrapped around to 255 when written to the `ArgumentInfo` struct. When I switched the `DeviceIndex` to `int16_t`, it actually stayed 255 after unpacking from `ArgumentInfo` again, as the `DeviceIndex` was now wide enough that it didn't wrap back to -1. Pull Request resolved: https://github.com/pytorch/pytorch/pull/119639 Approved by: https://github.com/cyyever, https://github.com/albanD, https://github.com/huydhn
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							cbbc309cae
						
					
				
				
					commit
					7c556428c7
				
			| @ -37,8 +37,8 @@ constexpr int checkStaticTypes() { | ||||
|  // Give nice error messages for some of the common error cases. | ||||
|  // Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT | ||||
|  static_assert(std::conjunction< | ||||
|      bool_t<!std::is_integral<Types>::value || std::is_same<Types, int8_t>::value || std::is_same<Types, int64_t>::value || std::is_same<Types, bool>::value>... | ||||
|    >::value, "INVALID TYPE: Only int8_t, int64_t and bool are supported as an integral argument type"); | ||||
|      bool_t<!std::is_integral<Types>::value || std::is_same<Types, int16_t>::value || std::is_same<Types, int64_t>::value || std::is_same<Types, bool>::value>... | ||||
|    >::value, "INVALID TYPE: Only int16_t, int64_t and bool are supported as an integral argument type"); | ||||
|  static_assert(std::conjunction< | ||||
|      bool_t<!std::is_same<Types, float>::value>... | ||||
|    >::value, "INVALID TYPE: float is not supported as an argument type, use double instead"); | ||||
|  | ||||
| @ -43,8 +43,7 @@ static inline void launch_jitted_vectorized_kernel_dynamic( | ||||
|   ss << static_cast<int>(at::cuda::jit::BinaryFuncVariant::NoScalar); | ||||
|   ss << extra_args_types; | ||||
|   ss << vec_size; | ||||
| // DeviceIndex, e.g. int8_t, is not treated as a number by the stream, cast to int as a workaround | ||||
|   ss << static_cast<int>(dev_idx); | ||||
|   ss << dev_idx; | ||||
|   const std::string cache_key = ss.str(); | ||||
|  | ||||
|   static std::mutex _jiterator_mutex; | ||||
|  | ||||
| @ -252,10 +252,17 @@ using IndicesT = std::vector<size_t>; | ||||
| using nested_optional_tensorvec_t = | ||||
|     std::vector<std::vector<c10::optional<at::Tensor>>>; | ||||
| using TensorsAndIndicesT = std::pair<nested_optional_tensorvec_t, IndicesT>; | ||||
| using FlatMap = std::unordered_map< | ||||
|     DeviceDtypeKey, | ||||
|     TensorsAndIndicesT, | ||||
|     ParamsHash<DeviceDtypeKey>>; | ||||
|  | ||||
| // Warning: Do not use ParamsHash for keys with potentially uninitialized | ||||
| // padding bytes! | ||||
| struct _DeviceDtypeHasher { | ||||
|   std::size_t operator()(const DeviceDtypeKey& k) const noexcept { | ||||
|     return std::hash<at::Device>{}(k.first) ^ | ||||
|         std::hash<at::ScalarType>{}(k.second); | ||||
|   } | ||||
| }; | ||||
| using FlatMap = | ||||
|     std::unordered_map<DeviceDtypeKey, TensorsAndIndicesT, _DeviceDtypeHasher>; | ||||
|  | ||||
| inline FlatMap _group_tensors_by_first_tensors_device_and_dtype( | ||||
|     const nested_optional_tensorvec_t& nested_tensorlist, | ||||
|  | ||||
| @ -10,6 +10,8 @@ namespace at::native { | ||||
| // Fowler–Noll–Vo hash function | ||||
| // see | ||||
| // https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function | ||||
| // WARNING: This hash function will produce unexpected results for `Params` with uninitialized padding values, as the | ||||
| // padding is also part of the hash. Use with caution. | ||||
| template <typename Params> | ||||
| struct ParamsHash { | ||||
|   // Params must be a POD because we read out its memory | ||||
|  | ||||
| @ -125,19 +125,29 @@ Device::Device(const std::string& device_string) : Device(Type::CPU) { | ||||
|  | ||||
|   TORCH_CHECK(!has_error, "Invalid device string: '", device_string, "'"); | ||||
|  | ||||
|   try { | ||||
|     if (!device_index_str.empty()) { | ||||
|       index_ = static_cast<c10::DeviceIndex>(std::stoi(device_index_str)); | ||||
|   if (!device_index_str.empty()) { | ||||
|     // If the user passed an index in the device string, check if it is a valid | ||||
|     // int between 0 and c10::Device::MAX_NUM_DEVICES - 1 inclusively | ||||
|     int full_index = -1; | ||||
|     try { | ||||
|       full_index = std::stoi(device_index_str); | ||||
|     } catch (const std::exception&) { | ||||
|       TORCH_CHECK( | ||||
|           false, | ||||
|           "Could not parse device index '", | ||||
|           device_index_str, | ||||
|           "' in device string '", | ||||
|           device_string, | ||||
|           "'"); | ||||
|     } | ||||
|   } catch (const std::exception&) { | ||||
|     TORCH_CHECK( | ||||
|         false, | ||||
|         "Could not parse device index '", | ||||
|         device_index_str, | ||||
|         "' in device string '", | ||||
|         device_string, | ||||
|         "'"); | ||||
|         0 <= full_index && full_index < c10::Device::MAX_NUM_DEVICES, | ||||
|         "Device index must be between 0 and ", | ||||
|         c10::Device::MAX_NUM_DEVICES - 1, | ||||
|         " inclusively."); | ||||
|     index_ = static_cast<c10::DeviceIndex>(full_index); | ||||
|   } | ||||
|  | ||||
|   type_ = parse_type(device_name); | ||||
|   validate(); | ||||
| } | ||||
|  | ||||
| @ -16,7 +16,7 @@ namespace c10 { | ||||
| /// A DeviceIndex is not independently meaningful without knowing | ||||
| /// the DeviceType it is associated; try to use Device rather than | ||||
| /// DeviceIndex directly. | ||||
| using DeviceIndex = int8_t; | ||||
| using DeviceIndex = int16_t; | ||||
|  | ||||
| /// Represents a compute device on which a tensor is located. A device is | ||||
| /// uniquely identified by a type, which specifies the type of machine it is | ||||
| @ -29,6 +29,18 @@ using DeviceIndex = int8_t; | ||||
| /// represents a specific, concrete device, | ||||
| /// 2. When the device type is CPU, the device index must be zero. | ||||
| struct C10_API Device final { | ||||
|   /// The maximum number of devices that we recognize (formerly known as | ||||
|   /// C10_COMPILE_TIME_MAX_GPUS). This value cannot be more than 32767 because | ||||
|   /// our DeviceIndex is an int16_t. Note that this does not include the default | ||||
|   /// device index -1, but instead defines the range from 0 to MAX_NUM_DEVICES-1 | ||||
|   /// inclusively. | ||||
| #ifdef FBCODE_CAFFE2 | ||||
|   // fbcode depends on this value being 16 | ||||
|   static constexpr DeviceIndex MAX_NUM_DEVICES = 16; | ||||
| #else | ||||
|   static constexpr DeviceIndex MAX_NUM_DEVICES = 512; | ||||
| #endif | ||||
|  | ||||
|   using Type = DeviceType; | ||||
|  | ||||
|   /// Constructs a new `Device` from a `DeviceType` and an optional device | ||||
| @ -60,6 +72,7 @@ struct C10_API Device final { | ||||
|   /// Sets the device index. | ||||
|   void set_index(DeviceIndex index) { | ||||
|     index_ = index; | ||||
|     validate(); | ||||
|   } | ||||
|  | ||||
|   /// Returns the type of device this is. | ||||
| @ -175,8 +188,10 @@ struct C10_API Device final { | ||||
|     // This is safe to do, because backends that use the DeviceIndex | ||||
|     // have a later check when we actually try to switch to that device. | ||||
|     TORCH_INTERNAL_ASSERT_DEBUG_ONLY( | ||||
|         index_ >= -1, | ||||
|         "Device index must be -1 or non-negative, got ", | ||||
|         index_ >= -1 && index_ < MAX_NUM_DEVICES, | ||||
|         "Device index must be between -1 and ", | ||||
|         MAX_NUM_DEVICES - 1, | ||||
|         " inclusively, got ", | ||||
|         static_cast<int>(index_)); | ||||
|     TORCH_INTERNAL_ASSERT_DEBUG_ONLY( | ||||
|         !is_cpu() || index_ <= 0, | ||||
| @ -196,7 +211,7 @@ struct hash<c10::Device> { | ||||
|     // Are you here because this static assert failed?  Make sure you ensure | ||||
|     // that the bitmasking code below is updated accordingly! | ||||
|     static_assert(sizeof(c10::DeviceType) == 1, "DeviceType is not 8-bit"); | ||||
|     static_assert(sizeof(c10::DeviceIndex) == 1, "DeviceIndex is not 8-bit"); | ||||
|     static_assert(sizeof(c10::DeviceIndex) == 2, "DeviceIndex is not 16-bit"); | ||||
|     // Note [Hazard when concatenating signed integers] | ||||
|     // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
|     // We must first convert to a same-sized unsigned type, before promoting to | ||||
| @ -209,7 +224,7 @@ struct hash<c10::Device> { | ||||
|     // sake. | ||||
|     uint32_t bits = static_cast<uint32_t>(static_cast<uint8_t>(d.type())) | ||||
|             << 16 | | ||||
|         static_cast<uint32_t>(static_cast<uint8_t>(d.index())); | ||||
|         static_cast<uint32_t>(static_cast<uint16_t>(d.index())); | ||||
|     return std::hash<uint32_t>{}(bits); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| @ -3169,7 +3169,7 @@ class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl { | ||||
| #if UINTPTR_MAX == 0xFFFFFFFF | ||||
|   // This is a 32-bit system | ||||
|   static constexpr bool check_sizes() { | ||||
|     constexpr size_t tsize = 20 * sizeof(int64_t); | ||||
|     constexpr size_t tsize = 21 * sizeof(int64_t); | ||||
|  | ||||
|     // clang-format off | ||||
|     are_equal<sizeof(storage_),            4,  FieldNameEnum::storage_>(); | ||||
| @ -3181,7 +3181,7 @@ class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl { | ||||
|     are_equal<sizeof(storage_offset_),     8,  FieldNameEnum::storage_offset_>(); | ||||
|     are_equal<sizeof(numel_),              8,  FieldNameEnum::numel_>(); | ||||
|     are_equal<sizeof(data_type_),          2,  FieldNameEnum::data_type_>(); | ||||
|     are_equal<sizeof(device_opt_),         3,  FieldNameEnum::device_opt_>(); | ||||
|     are_equal<sizeof(device_opt_),         6,  FieldNameEnum::device_opt_>(); | ||||
|     are_equal<sizeof(key_set_),            8,  FieldNameEnum::key_set_>(); | ||||
|     is_le<sizeof(TensorImpl),          tsize,  FieldNameEnum::TOTAL_SIZE>(); | ||||
|     // clang-format on | ||||
| @ -3206,7 +3206,7 @@ class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl { | ||||
|     are_equal<sizeof(storage_offset_),     8,  FieldNameEnum::storage_offset_>(); | ||||
|     are_equal<sizeof(numel_),              8,  FieldNameEnum::numel_>(); | ||||
|     are_equal<sizeof(data_type_),          2,  FieldNameEnum::data_type_>(); | ||||
|     are_equal<sizeof(device_opt_),         3,  FieldNameEnum::device_opt_>(); | ||||
|     are_equal<sizeof(device_opt_),         6,  FieldNameEnum::device_opt_>(); | ||||
|     are_equal<sizeof(key_set_),            8,  FieldNameEnum::key_set_>(); | ||||
|     is_le<sizeof(TensorImpl),          tsize,  FieldNameEnum::TOTAL_SIZE>(); | ||||
|     // clang-format on | ||||
|  | ||||
| @ -99,7 +99,7 @@ DeviceIndex device_count() noexcept { | ||||
|     try { | ||||
|       auto result = device_count_impl(/*fail_if_no_driver=*/false); | ||||
|       TORCH_INTERNAL_ASSERT( | ||||
|           result <= std::numeric_limits<DeviceIndex>::max(), | ||||
|           result <= c10::Device::MAX_NUM_DEVICES, | ||||
|           "Too many CUDA devices, DeviceIndex overflowed"); | ||||
|       return result; | ||||
|     } catch (const c10::Error& ex) { | ||||
| @ -118,7 +118,7 @@ DeviceIndex device_count_ensure_non_zero() { | ||||
|   // Zero gpus doesn't produce a warning in `device_count` but we fail here | ||||
|   TORCH_CHECK(count, "No CUDA GPUs are available"); | ||||
|   TORCH_INTERNAL_ASSERT( | ||||
|       count <= std::numeric_limits<DeviceIndex>::max(), | ||||
|       count <= c10::Device::MAX_NUM_DEVICES, | ||||
|       "Too many CUDA devices, DeviceIndex overflowed"); | ||||
|   return static_cast<DeviceIndex>(count); | ||||
| } | ||||
| @ -219,8 +219,7 @@ cudaError_t GetDevice(DeviceIndex* device) { | ||||
|   auto err = cudaGetDevice(&tmp_device); | ||||
|   if (err == cudaSuccess) { | ||||
|     TORCH_INTERNAL_ASSERT( | ||||
|         tmp_device >= 0 && | ||||
|             tmp_device <= std::numeric_limits<DeviceIndex>::max(), | ||||
|         tmp_device >= 0 && tmp_device < c10::Device::MAX_NUM_DEVICES, | ||||
|         "cudaGetDevice returns invalid device ", | ||||
|         tmp_device); | ||||
|     *device = static_cast<DeviceIndex>(tmp_device); | ||||
| @ -270,8 +269,7 @@ DeviceIndex MaybeExchangeDevice(DeviceIndex to_device) { | ||||
|   int tmp_cur_device = -1; | ||||
|   C10_CUDA_CHECK(cudaGetDevice(&tmp_cur_device)); | ||||
|   TORCH_INTERNAL_ASSERT( | ||||
|       tmp_cur_device >= 0 && | ||||
|           tmp_cur_device <= std::numeric_limits<DeviceIndex>::max(), | ||||
|       tmp_cur_device >= 0 && tmp_cur_device < c10::Device::MAX_NUM_DEVICES, | ||||
|       "cudaGetDevice returns invalid device ", | ||||
|       tmp_cur_device); | ||||
|   auto cur_device = static_cast<DeviceIndex>(tmp_cur_device); | ||||
| @ -297,8 +295,7 @@ cudaError_t GetDevice(DeviceIndex* device) { | ||||
|   auto err = cudaGetDevice(&tmp_device); | ||||
|   if (err == cudaSuccess) { | ||||
|     TORCH_INTERNAL_ASSERT( | ||||
|         tmp_device >= 0 && | ||||
|             tmp_device <= std::numeric_limits<DeviceIndex>::max(), | ||||
|         tmp_device >= 0 && tmp_device < c10::Device::MAX_NUM_DEVICES, | ||||
|         "cudaGetDevice returns invalid device ", | ||||
|         tmp_device); | ||||
|     *device = static_cast<DeviceIndex>(tmp_device); | ||||
|  | ||||
| @ -37,15 +37,3 @@ | ||||
| #else | ||||
| #define C10_CUDA_API C10_CUDA_IMPORT | ||||
| #endif | ||||
|  | ||||
| /** | ||||
|  * The maximum number of GPUs that we recognize. Increasing this beyond the | ||||
|  * initial limit of 16 broke Caffe2 testing, hence the ifdef guards. | ||||
|  * This value cannot be more than 255 because our DeviceIndex is a uint8_t. | ||||
|  */ | ||||
| #ifdef FBCODE_CAFFE2 | ||||
| // fbcode depends on this value being 16 | ||||
| #define C10_COMPILE_TIME_MAX_GPUS 16 | ||||
| #else | ||||
| #define C10_COMPILE_TIME_MAX_GPUS 64 | ||||
| #endif | ||||
|  | ||||
| @ -38,18 +38,18 @@ static int max_stream_priorities; | ||||
| // the destruction. | ||||
| #if !defined(USE_ROCM) | ||||
| // CUDA-only: used to initialize the stream pools (once) | ||||
| static c10::once_flag device_flags[C10_COMPILE_TIME_MAX_GPUS]; | ||||
| static c10::once_flag device_flags[c10::Device::MAX_NUM_DEVICES]; | ||||
| #endif | ||||
| static std::atomic<uint32_t> | ||||
|     priority_counters[c10::cuda::max_compile_time_stream_priorities] | ||||
|                      [C10_COMPILE_TIME_MAX_GPUS]; | ||||
|                      [c10::Device::MAX_NUM_DEVICES]; | ||||
|  | ||||
| static cudaStream_t streams[c10::cuda::max_compile_time_stream_priorities] | ||||
|                            [C10_COMPILE_TIME_MAX_GPUS][kStreamsPerPool]; | ||||
|                            [c10::Device::MAX_NUM_DEVICES][kStreamsPerPool]; | ||||
| #ifdef USE_ROCM | ||||
| static c10::once_flag | ||||
|     stream_flags[c10::cuda::max_compile_time_stream_priorities] | ||||
|                 [C10_COMPILE_TIME_MAX_GPUS][kStreamsPerPool]; | ||||
|                 [c10::Device::MAX_NUM_DEVICES][kStreamsPerPool]; | ||||
| #endif | ||||
|  | ||||
| // Note [HIP Lazy Streams] | ||||
| @ -168,10 +168,10 @@ static void initGlobalStreamState() { | ||||
|   // Check if the number of GPUs matches the expected compile-time max number | ||||
|   // of GPUs. | ||||
|   TORCH_CHECK( | ||||
|       num_gpus <= C10_COMPILE_TIME_MAX_GPUS, | ||||
|       num_gpus <= c10::Device::MAX_NUM_DEVICES, | ||||
|       "Number of CUDA devices on the machine is larger than the compiled " | ||||
|       "max number of gpus expected (", | ||||
|       C10_COMPILE_TIME_MAX_GPUS, | ||||
|       c10::Device::MAX_NUM_DEVICES, | ||||
|       "). Increase that and recompile."); | ||||
|   int leastPriority = -1, greatestPriority = -1; | ||||
|   C10_CUDA_CHECK( | ||||
|  | ||||
| @ -224,8 +224,8 @@ std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>> ncclOpDevInfer( | ||||
|  | ||||
| REGISTER_CUDA_OPERATOR(NCCLAllreduce, NCCLAllreduceOp); | ||||
| OPERATOR_SCHEMA(NCCLAllreduce) | ||||
|     .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS) | ||||
|     .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS) | ||||
|     .NumInputs(1, c10::Device::MAX_NUM_DEVICES) | ||||
|     .NumOutputs(1, c10::Device::MAX_NUM_DEVICES) | ||||
|     .CostInferenceFunction(NCCLAllreduceOp::CostInference) | ||||
|     .TensorInferenceFunction(NCCLAllreduceOp::ShapeInference) | ||||
|     .IdenticalTypeAndShape() | ||||
| @ -236,8 +236,8 @@ SHOULD_NOT_DO_GRADIENT(NCCLAllreduce); | ||||
|  | ||||
| REGISTER_CUDA_OPERATOR(NCCLBroadcast, NCCLBroadcastOp); | ||||
| OPERATOR_SCHEMA(NCCLBroadcast) | ||||
|     .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS) | ||||
|     .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS) | ||||
|     .NumInputs(1, c10::Device::MAX_NUM_DEVICES) | ||||
|     .NumOutputs(1, c10::Device::MAX_NUM_DEVICES) | ||||
|     .IdenticalTypeAndShape() | ||||
|     .InputsCanCrossDevices() | ||||
|     .EnforceOneToOneInplace() | ||||
| @ -247,7 +247,7 @@ SHOULD_NOT_DO_GRADIENT(NCCLBroadcast); | ||||
|  | ||||
| REGISTER_CUDA_OPERATOR(NCCLReduce, NCCLReduceOp); | ||||
| OPERATOR_SCHEMA(NCCLReduce) | ||||
|     .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS) | ||||
|     .NumInputs(1, c10::Device::MAX_NUM_DEVICES) | ||||
|     .NumOutputs(1) | ||||
|     .IdenticalTypeAndShapeOfInput(0) | ||||
|     .InputsCanCrossDevices() | ||||
| @ -257,16 +257,16 @@ SHOULD_NOT_DO_GRADIENT(NCCLReduce); | ||||
|  | ||||
| REGISTER_CUDA_OPERATOR(NCCLAllGather, NCCLAllGatherOp); | ||||
| OPERATOR_SCHEMA(NCCLAllGather) | ||||
|     .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS) | ||||
|     .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS) | ||||
|     .NumInputs(1, c10::Device::MAX_NUM_DEVICES) | ||||
|     .NumOutputs(1, c10::Device::MAX_NUM_DEVICES) | ||||
|     .InputsCanCrossDevices() | ||||
|     .DeviceInferenceFunction(ncclOpDevInfer); | ||||
| SHOULD_NOT_DO_GRADIENT(NCCLAllGather); | ||||
|  | ||||
| REGISTER_CUDA_OPERATOR(NCCLReduceScatter, NCCLReduceScatterOp); | ||||
| OPERATOR_SCHEMA(NCCLReduceScatter) | ||||
|     .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS) | ||||
|     .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS) | ||||
|     .NumInputs(1, c10::Device::MAX_NUM_DEVICES) | ||||
|     .NumOutputs(1, c10::Device::MAX_NUM_DEVICES) | ||||
|     .InputsCanCrossDevices() | ||||
|     .DeviceInferenceFunction(ncclOpDevInfer); | ||||
| SHOULD_NOT_DO_GRADIENT(NCCLReduceScatter); | ||||
|  | ||||
| @ -178,8 +178,8 @@ static std::unordered_map<void*, uint8_t> g_cuda_device_affiliation; | ||||
| // Data structures for optional memory tracking. Access to these structures | ||||
| // is guarded by the CUDAContext::mutex. | ||||
| static std::unordered_map<void*, long> g_size_map; | ||||
| static std::vector<long> g_total_by_gpu_map(C10_COMPILE_TIME_MAX_GPUS, 0); | ||||
| static std::vector<long> g_max_by_gpu_map(C10_COMPILE_TIME_MAX_GPUS, 0); | ||||
| static std::vector<long> g_total_by_gpu_map(c10::Device::MAX_NUM_DEVICES, 0); | ||||
| static std::vector<long> g_max_by_gpu_map(c10::Device::MAX_NUM_DEVICES, 0); | ||||
|  | ||||
| static long g_total_mem = 0; | ||||
| static long g_last_rep = 0; | ||||
| @ -208,10 +208,10 @@ static void Caffe2InitializeCuda() { | ||||
|   // of GPUs. | ||||
|   CAFFE_ENFORCE_LE( | ||||
|       NumCudaDevices(), | ||||
|       C10_COMPILE_TIME_MAX_GPUS, | ||||
|       c10::Device::MAX_NUM_DEVICES, | ||||
|       "Number of CUDA devices on the machine is larger than the compiled " | ||||
|       "max number of gpus expected (", | ||||
|       C10_COMPILE_TIME_MAX_GPUS, | ||||
|       c10::Device::MAX_NUM_DEVICES, | ||||
|       "). Increase that and recompile."); | ||||
|  | ||||
|   for (DeviceIndex i = 0; i < NumCudaDevices(); ++i) { | ||||
|  | ||||
| @ -58,7 +58,7 @@ class CAFFE2_CUDA_API ThreadLocalCUDAObjects { | ||||
|  | ||||
|  private: | ||||
|   ThreadLocalCUDAObjects() { | ||||
|     for (DeviceIndex i = 0; i < C10_COMPILE_TIME_MAX_GPUS; ++i) { | ||||
|     for (DeviceIndex i = 0; i < c10::Device::MAX_NUM_DEVICES; ++i) { | ||||
|       cuda_streams_[i] = vector<c10::cuda::CUDAStream>(); | ||||
|     } | ||||
|   } | ||||
| @ -164,7 +164,7 @@ class CAFFE2_CUDA_API ThreadLocalCUDAObjects { | ||||
|   // WARNING: mapping from logical stream ID to c10::cuda::CUDAStream | ||||
|   // is NOT bijective; multiple logical stream IDs may map to the | ||||
|   // same underlying stream ID. | ||||
|   vector<c10::cuda::CUDAStream> cuda_streams_[C10_COMPILE_TIME_MAX_GPUS]; | ||||
|   vector<c10::cuda::CUDAStream> cuda_streams_[c10::Device::MAX_NUM_DEVICES]; | ||||
|   std::unordered_map<c10::cuda::CUDAStream, cublasHandle_t> cublas_handles_; | ||||
| #ifdef CAFFE2_USE_CUDNN | ||||
|   std::unordered_map<c10::cuda::CUDAStream, cudnnHandle_t> cudnn_handles_; | ||||
|  | ||||
| @ -188,7 +188,7 @@ class CuDNNWrapper { | ||||
|  | ||||
|   using PerGPUCuDNNStates = std::array< | ||||
|       std::array<SyncedCuDNNState, CAFFE2_COMPILE_TIME_MAX_CUDNN_STATES>, | ||||
|       C10_COMPILE_TIME_MAX_GPUS>; | ||||
|       c10::Device::MAX_NUM_DEVICES>; | ||||
|   static PerGPUCuDNNStates& cudnn_states(); | ||||
|  | ||||
|   C10_DISABLE_COPY_AND_ASSIGN(CuDNNWrapper); | ||||
|  | ||||
| @ -155,7 +155,7 @@ class MIOPENWrapper | ||||
|  | ||||
|     using PerGPUMIOPENStates = std::array< | ||||
|         std::array<SyncedMIOPENState, CAFFE2_COMPILE_TIME_MAX_MIOPEN_STATES>, | ||||
|         C10_COMPILE_TIME_MAX_GPUS>; | ||||
|         c10::Device::MAX_NUM_DEVICES>; | ||||
|     static PerGPUMIOPENStates& miopen_states(); | ||||
|  | ||||
|     C10_DISABLE_COPY_AND_ASSIGN(MIOPENWrapper); | ||||
|  | ||||
| @ -3440,6 +3440,21 @@ def foo(x): | ||||
|             else: | ||||
|                 cu.define(full) | ||||
|  | ||||
|     def test_int16_device_index(self): | ||||
|         # This used to fail after the switch from int8 to int16 DeviceIndex as the ArgumentInfo struct hardcoded | ||||
|         # the bit width. Thus, the default device (-1) wrapped around to 255. | ||||
|         # See https://github.com/pytorch/pytorch/issues/115331 | ||||
|         tensor = torch.tensor([1.]) | ||||
|         code_template = """ | ||||
|         def fn(x): | ||||
|             return x.device | ||||
|         """ | ||||
|         cu = torch.jit.CompilationUnit() | ||||
|         cu.define(code_template) | ||||
|         res = cu.fn(tensor) | ||||
|         self.assertEqual(tensor.device, res) | ||||
|  | ||||
|  | ||||
|     def test_namedtuple_python(self): | ||||
|         global MyTuple, MyMod  # see [local resolution in python] | ||||
|         MyTuple = namedtuple('MyTuple', ['a']) | ||||
|  | ||||
| @ -1017,6 +1017,22 @@ class TestDeviceUtils(TestCase): | ||||
|                 tree_all_only(torch.Tensor, lambda x: x.device.type == 'meta', r) | ||||
|             ) | ||||
|  | ||||
|     def test_int16_device_index(self): | ||||
|         # Test if index does not wrap around when larger than int8 | ||||
|         large_index = 500 | ||||
|         x = torch.device('meta', large_index) | ||||
|         self.assertEqual(x.index, large_index) | ||||
|  | ||||
|     def test_raise_on_device_index_out_of_bounds(self): | ||||
|         # Tests if an error is raised when the device index is out of bounds | ||||
|         index_larger_than_max = 100000 | ||||
|         error_msg_regex = "^Device index must be.*" | ||||
|         # Explicit index | ||||
|         with self.assertRaisesRegex(RuntimeError, error_msg_regex): | ||||
|             x = torch.device('meta', index=index_larger_than_max) | ||||
|         # Index in device string | ||||
|         with self.assertRaisesRegex(RuntimeError, error_msg_regex): | ||||
|             x = torch.device(f'meta:{index_larger_than_max}') | ||||
|  | ||||
| instantiate_device_type_tests(TestDeviceUtils, globals()) | ||||
|  | ||||
|  | ||||
| @ -31,10 +31,7 @@ PyObject* THPDevice_repr(THPDevice* self) { | ||||
|   std::ostringstream oss; | ||||
|   oss << "device(type=\'" << self->device.type() << "\'"; | ||||
|   if (self->device.has_index()) { | ||||
|     // `self->device.index()` returns uint8_t which is treated as ascii while | ||||
|     // printing, hence casting it to uint16_t. | ||||
|     // https://stackoverflow.com/questions/19562103/uint8-t-cant-be-printed-with-cout | ||||
|     oss << ", index=" << static_cast<uint16_t>(self->device.index()); | ||||
|     oss << ", index=" << self->device.index(); | ||||
|   } | ||||
|   oss << ")"; | ||||
|   return THPUtils_packString(oss.str().c_str()); | ||||
| @ -77,7 +74,11 @@ PyObject* THPDevice_pynew( | ||||
|       device_index = r.toInt64(1); | ||||
|       // -1 is allowed in ATen/C++, to mean the default device, but not in | ||||
|       // Python. | ||||
|       TORCH_CHECK(device_index >= 0, "Device index must not be negative"); | ||||
|       TORCH_CHECK( | ||||
|           device_index >= 0 && device_index < c10::Device::MAX_NUM_DEVICES, | ||||
|           "Device index must be between 0 and ", | ||||
|           c10::Device::MAX_NUM_DEVICES - 1, | ||||
|           " inclusively."); | ||||
|     } | ||||
|     at::Device device( | ||||
|         as_device.type(), static_cast<c10::DeviceIndex>(device_index)); | ||||
|  | ||||
| @ -2028,23 +2028,10 @@ Call this whenever a new thread is created in order to propagate values from | ||||
|   // torch/csrc/pybind.h` would solve this but it caused segmentation fault in | ||||
|   // my environment. | ||||
|   using _DeviceDtypeKey = std::pair<at::Device, std::string>; | ||||
|   // Custom hasher is necessary to make unordered_map compilable for Windows | ||||
|   // debug targets. As `at::native::ParamsHash` only works on structs with | ||||
|   // standard layout, but std::string isn't one in Visual C++ debug builds, | ||||
|   // which one can easily verify by running something like: | ||||
|   //   #define _DEBUG | ||||
|   //   #include <type_traits> | ||||
|   //   #include <string> | ||||
|   //   static_assert(std::is_standard_layout_v<std::string>, "Oh noes"); | ||||
|   // If above condition is not met, VC++ raises a very cryptic compilation | ||||
|   // error. See | ||||
|   // https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for | ||||
|   // more detail | ||||
|   struct _DeviceDtypeHasher { | ||||
|     std::size_t operator()(const _DeviceDtypeKey& k) const noexcept { | ||||
|       static at::native::ParamsHash<at::Device> device_hasher; | ||||
|       static std::hash<std::string> string_hasher; | ||||
|       return device_hasher(k.first) ^ string_hasher(k.second); | ||||
|       return std::hash<at::Device>{}(k.first) ^ | ||||
|           std::hash<std::string>{}(k.second); | ||||
|     } | ||||
|   }; | ||||
|   using _FlatMap = std::unordered_map< | ||||
|  | ||||
| @ -1,11 +1,12 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <c10/core/Device.h> | ||||
| #include <c10/cuda/CUDAMacros.h> | ||||
| #include <bitset> | ||||
| #include <cstddef> | ||||
|  | ||||
| namespace torch { | ||||
|  | ||||
| using device_set = std::bitset<C10_COMPILE_TIME_MAX_GPUS>; | ||||
| using device_set = std::bitset<c10::Device::MAX_NUM_DEVICES>; | ||||
|  | ||||
| } // namespace torch | ||||
|  | ||||
| @ -2,6 +2,7 @@ | ||||
|  | ||||
| #include <ATen/core/jit_type.h> | ||||
| #include <ATen/core/stack.h> | ||||
| #include <c10/core/Device.h> | ||||
| #include <c10/util/hash.h> | ||||
| #include <c10/util/irange.h> | ||||
| #include <torch/csrc/Export.h> | ||||
| @ -56,12 +57,10 @@ struct ArgumentInfo { | ||||
|  private: | ||||
|   unsigned defined_ : 1; | ||||
|   unsigned requires_grad_ : 1; | ||||
|   unsigned : 5; | ||||
|   unsigned dim_ : 8; | ||||
|   unsigned device_ : 8; | ||||
|   signed device_ : sizeof(c10::DeviceIndex) * 8; | ||||
|   unsigned type_ : 8; | ||||
|   unsigned dev_type_ : 16; | ||||
|   unsigned : 16; | ||||
| }; | ||||
|  | ||||
| static_assert( | ||||
| @ -69,7 +68,7 @@ static_assert( | ||||
|     "ArgumentInfo is to be a POD struct"); | ||||
| static_assert( | ||||
|     sizeof(ArgumentInfo) == sizeof(ArgumentInfo::plain_data_type), | ||||
|     "ArgumentInfo is expected to be a 32-bit struct"); | ||||
|     "ArgumentInfo is expected to be a 64-bit struct"); | ||||
|  | ||||
| struct ArgumentSpec { | ||||
|   ArgumentSpec(size_t num_flat_tensor_inputs, size_t num_flat_optional_inputs) | ||||
| @ -223,8 +222,8 @@ struct CompleteArgumentInfoPOD { | ||||
|   unsigned type : 8; // scalar type | ||||
|   unsigned defined : 1; | ||||
|   unsigned requires_grad : 1; | ||||
|   signed device : 14; | ||||
|   unsigned dev_type : 16; | ||||
|   signed dev_type : sizeof(c10::DeviceType) * 8; | ||||
|   signed device : sizeof(c10::DeviceIndex) * 8; | ||||
|   unsigned | ||||
|       total_dims : 16; // all TensorInfoPODs are in CompleteArgumentSpec's | ||||
|                        // tensor_info() array. total_dims is the total number of | ||||
|  | ||||
| @ -807,7 +807,11 @@ inline at::Device toDevice(PyObject* obj) { | ||||
|   } | ||||
|   if (THPUtils_checkLong(obj)) { | ||||
|     const auto device_index = THPUtils_unpackLong(obj); | ||||
|     TORCH_CHECK(device_index >= 0, "Device index must not be negative"); | ||||
|     TORCH_CHECK( | ||||
|         device_index >= 0 && device_index < c10::Device::MAX_NUM_DEVICES, | ||||
|         "Device index must be between 0 and ", | ||||
|         c10::Device::MAX_NUM_DEVICES - 1, | ||||
|         " inclusively."); | ||||
|     if (c10::is_privateuse1_backend_registered()) { | ||||
|       return at::Device( | ||||
|           c10::DeviceType::PrivateUse1, | ||||
|  | ||||
| @ -5258,12 +5258,12 @@ class TensorPipeAgentCudaRpcTest(RpcAgentTestFixture, RpcTestCommon): | ||||
|         options = self.rpc_backend_options | ||||
|         dst = worker_name((self.rank + 1) % self.world_size) | ||||
|         with self.assertRaisesRegex( | ||||
|             RuntimeError, "Device index must not be negative" | ||||
|             RuntimeError, "Device index must .*" | ||||
|         ): | ||||
|             options.set_device_map(dst, {-1: 0}) | ||||
|  | ||||
|         with self.assertRaisesRegex( | ||||
|             RuntimeError, "Device index must not be negative" | ||||
|             RuntimeError, "Device index must .*" | ||||
|         ): | ||||
|             options.set_device_map(dst, {0: -1}) | ||||
|  | ||||
|  | ||||
		Reference in New Issue
	
	Block a user