Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Revert "Increased compile time max GPUs to 512. Switched to int16_t DeviceIndex. (#119639)"
This reverts commit 7c556428c74a79c6d9c272826344a0828d3f66f5. Reverted https://github.com/pytorch/pytorch/pull/119639 on behalf of https://github.com/kit1980 due to breaking internal builds, see D54286923 ([comment](https://github.com/pytorch/pytorch/pull/119639#issuecomment-1969634480))
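The change being reverted (PR #119639) widened c10::DeviceIndex from int8_t to int16_t and replaced the compile-time GPU cap C10_COMPILE_TIME_MAX_GPUS (16 for fbcode, 64 otherwise) with a c10::Device::MAX_NUM_DEVICES constant of up to 512. The hunks below undo exactly that: lines marked with - are the int16_t-era code this revert deletes, and lines marked with + are the int8_t-era code it restores. For orientation, a minimal standalone sketch in plain C++ (not PyTorch code) of why the width of the typedef is the limiting factor:

    #include <cstdint>
    #include <iostream>
    #include <limits>

    // Hypothetical stand-ins for the two versions of the c10::DeviceIndex typedef.
    using DeviceIndex8 = int8_t;    // what this revert restores
    using DeviceIndex16 = int16_t;  // what PR #119639 had switched to

    int main() {
      // With int8_t the largest addressable device index is 127, so an index such as
      // 500 (used by the test_int16_device_index tests removed further down) cannot
      // be represented.
      std::cout << "int8_t  max index: "
                << static_cast<int>(std::numeric_limits<DeviceIndex8>::max()) << '\n';
      std::cout << "int16_t max index: "
                << std::numeric_limits<DeviceIndex16>::max() << '\n';
    }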
@@ -37,8 +37,8 @@ constexpr int checkStaticTypes() {
   // Give nice error messages for some of the common error cases.
   // Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT
   static_assert(std::conjunction<
-      bool_t<!std::is_integral<Types>::value || std::is_same<Types, int16_t>::value || std::is_same<Types, int64_t>::value || std::is_same<Types, bool>::value>...
-      >::value, "INVALID TYPE: Only int16_t, int64_t and bool are supported as an integral argument type");
+      bool_t<!std::is_integral<Types>::value || std::is_same<Types, int8_t>::value || std::is_same<Types, int64_t>::value || std::is_same<Types, bool>::value>...
+      >::value, "INVALID TYPE: Only int8_t, int64_t and bool are supported as an integral argument type");
   static_assert(std::conjunction<
       bool_t<!std::is_same<Types, float>::value>...
       >::value, "INVALID TYPE: float is not supported as an argument type, use double instead");
@@ -43,7 +43,8 @@ static inline void launch_jitted_vectorized_kernel_dynamic(
   ss << static_cast<int>(at::cuda::jit::BinaryFuncVariant::NoScalar);
   ss << extra_args_types;
   ss << vec_size;
-  ss << dev_idx;
+  // DeviceIndex, e.g. int8_t, is not treated as a number by the stream, cast to int as a workaround
+  ss << static_cast<int>(dev_idx);
   const std::string cache_key = ss.str();

   static std::mutex _jiterator_mutex;
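The comment restored in the hunk above points at a C++ stream quirk: int8_t is normally a typedef for signed char, so operator<< prints it as a character rather than as a number. A minimal standalone sketch (plain C++, not the jiterator code itself):

    #include <cstdint>
    #include <iostream>

    int main() {
      int8_t dev_idx = 65;
      std::cout << dev_idx << '\n';                    // prints the character 'A', not 65
      std::cout << static_cast<int>(dev_idx) << '\n';  // prints 65
    }

With the int16_t DeviceIndex the cast was unnecessary, which is why the reverted change had dropped it; going back to int8_t brings the workaround back.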
@@ -252,17 +252,10 @@ using IndicesT = std::vector<size_t>;
 using nested_optional_tensorvec_t =
     std::vector<std::vector<c10::optional<at::Tensor>>>;
 using TensorsAndIndicesT = std::pair<nested_optional_tensorvec_t, IndicesT>;
-
-// Warning: Do not use ParamsHash for keys with potentially uninitialized
-// padding bytes!
-struct _DeviceDtypeHasher {
-  std::size_t operator()(const DeviceDtypeKey& k) const noexcept {
-    return std::hash<at::Device>{}(k.first) ^
-        std::hash<at::ScalarType>{}(k.second);
-  }
-};
-using FlatMap =
-    std::unordered_map<DeviceDtypeKey, TensorsAndIndicesT, _DeviceDtypeHasher>;
+using FlatMap = std::unordered_map<
+    DeviceDtypeKey,
+    TensorsAndIndicesT,
+    ParamsHash<DeviceDtypeKey>>;

 inline FlatMap _group_tensors_by_first_tensors_device_and_dtype(
     const nested_optional_tensorvec_t& nested_tensorlist,
@@ -10,8 +10,6 @@ namespace at::native {
 // Fowler–Noll–Vo hash function
 // see
 // https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
-// WARNING: This hash function will produce unexpected results for `Params` with uninitialized padding values, as the
-// padding is also part of the hash. Use with caution.
 template <typename Params>
 struct ParamsHash {
   // Params must be a POD because we read out its memory
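The warning removed above existed because ParamsHash hashes the raw bytes of the key, and a 1-byte DeviceType next to a 2-byte int16_t index gives the key a padding byte that is never written. A standalone sketch of the hazard (an FNV-1a loop over the object representation, not the actual ParamsHash implementation):

    #include <cstddef>
    #include <cstdint>
    #include <cstring>
    #include <iostream>

    struct Params {
      uint8_t type;   // 1 byte; the compiler inserts 1 padding byte after it
      int16_t index;  // 2 bytes, 2-byte aligned
    };

    // Byte-wise FNV-1a over the whole object representation, padding included.
    std::size_t byte_hash(const Params& p) {
      std::size_t h = 0xcbf29ce484222325ull;
      const unsigned char* bytes = reinterpret_cast<const unsigned char*>(&p);
      for (std::size_t i = 0; i < sizeof(Params); ++i) {
        h = (h ^ bytes[i]) * 0x100000001b3ull;
      }
      return h;
    }

    int main() {
      Params a;                       // padding byte holds whatever was on the stack
      Params b;
      std::memset(&b, 0, sizeof(b));  // padding byte deterministically zeroed
      a.type = b.type = 1;
      a.index = b.index = 3;
      // Logically equal keys, but the hashes only match when the padding bytes
      // happen to agree -- exactly what the removed warning described.
      std::cout << (byte_hash(a) == byte_hash(b)) << '\n';
    }

With the int8_t DeviceIndex restored by this revert, Device packs into two bytes with no padding, so the byte-wise ParamsHash is safe again and the warning goes away.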
@@ -125,29 +125,19 @@ Device::Device(const std::string& device_string) : Device(Type::CPU) {

   TORCH_CHECK(!has_error, "Invalid device string: '", device_string, "'");

-  if (!device_index_str.empty()) {
-    // If the user passed an index in the device string, check if it is a valid
-    // int between 0 and c10::Device::MAX_NUM_DEVICES - 1 inclusively
-    int full_index = -1;
-    try {
-      full_index = std::stoi(device_index_str);
-    } catch (const std::exception&) {
-      TORCH_CHECK(
-          false,
-          "Could not parse device index '",
-          device_index_str,
-          "' in device string '",
-          device_string,
-          "'");
+  try {
+    if (!device_index_str.empty()) {
+      index_ = static_cast<c10::DeviceIndex>(std::stoi(device_index_str));
     }
+  } catch (const std::exception&) {
     TORCH_CHECK(
-        0 <= full_index && full_index < c10::Device::MAX_NUM_DEVICES,
-        "Device index must be between 0 and ",
-        c10::Device::MAX_NUM_DEVICES - 1,
-        " inclusively.");
-    index_ = static_cast<c10::DeviceIndex>(full_index);
+        false,
+        "Could not parse device index '",
+        device_index_str,
+        "' in device string '",
+        device_string,
+        "'");
   }

   type_ = parse_type(device_name);
   validate();
 }
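The restored control flow above simply lets std::stoi throw and turns the exception into the parse error, instead of range-checking the parsed value against a device cap. A standalone sketch of the same pattern (assuming a well-formed "type:index" string; this is not the real Device.cpp):

    #include <iostream>
    #include <stdexcept>
    #include <string>

    int main() {
      const std::string device_string = "cuda:1";
      const std::string device_index_str =
          device_string.substr(device_string.find(':') + 1);

      int index = -1;  // -1 means "default device"
      try {
        if (!device_index_str.empty()) {
          index = std::stoi(device_index_str);
        }
      } catch (const std::exception&) {
        std::cerr << "Could not parse device index '" << device_index_str
                  << "' in device string '" << device_string << "'\n";
        return 1;
      }
      std::cout << "parsed index: " << index << '\n';
    }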
@@ -16,7 +16,7 @@ namespace c10 {
 /// A DeviceIndex is not independently meaningful without knowing
 /// the DeviceType it is associated; try to use Device rather than
 /// DeviceIndex directly.
-using DeviceIndex = int16_t;
+using DeviceIndex = int8_t;

 /// Represents a compute device on which a tensor is located. A device is
 /// uniquely identified by a type, which specifies the type of machine it is
@@ -29,18 +29,6 @@ using DeviceIndex = int16_t;
 /// represents a specific, concrete device,
 /// 2. When the device type is CPU, the device index must be zero.
 struct C10_API Device final {
-  /// The maximum number of devices that we recognize (formerly known as
-  /// C10_COMPILE_TIME_MAX_GPUS). This value cannot be more than 32767 because
-  /// our DeviceIndex is a int16_t. Note that this does not include the default
-  /// device index -1, but instead defines the range from 0 to MAX_NUM_DEVICES-1
-  /// inclusively.
-#ifdef FBCODE_CAFFE2
-  // fbcode depends on this value being 16
-  static constexpr DeviceIndex MAX_NUM_DEVICES = 16;
-#else
-  static constexpr DeviceIndex MAX_NUM_DEVICES = 512;
-#endif
-
   using Type = DeviceType;

   /// Constructs a new `Device` from a `DeviceType` and an optional device
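The deleted doc comment ties the cap to the width of DeviceIndex (at most 32767 for int16_t, not counting the default index -1). A hypothetical standalone guard expressing the same constraint (the names are stand-ins, not the actual c10 declarations):

    #include <cstdint>
    #include <limits>

    using DeviceIndex = int16_t;                  // the typedef being reverted
    constexpr DeviceIndex MAX_NUM_DEVICES = 512;  // the cap being reverted

    static_assert(
        MAX_NUM_DEVICES <= std::numeric_limits<DeviceIndex>::max(),
        "the compile-time device cap must fit in DeviceIndex");

    int main() {}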
@@ -72,7 +60,6 @@ struct C10_API Device final {
   /// Sets the device index.
   void set_index(DeviceIndex index) {
     index_ = index;
-    validate();
   }

   /// Returns the type of device this is.
@@ -188,10 +175,8 @@ struct C10_API Device final {
     // This is safe to do, because backends that use the DeviceIndex
     // have a later check when we actually try to switch to that device.
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-        index_ >= -1 && index_ < MAX_NUM_DEVICES,
-        "Device index must be between -1 and ",
-        MAX_NUM_DEVICES - 1,
-        " inclusively, got ",
+        index_ >= -1,
+        "Device index must be -1 or non-negative, got ",
         static_cast<int>(index_));
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
         !is_cpu() || index_ <= 0,
@@ -211,7 +196,7 @@ struct hash<c10::Device> {
     // Are you here because this static assert failed? Make sure you ensure
     // that the bitmasking code below is updated accordingly!
     static_assert(sizeof(c10::DeviceType) == 1, "DeviceType is not 8-bit");
-    static_assert(sizeof(c10::DeviceIndex) == 2, "DeviceIndex is not 16-bit");
+    static_assert(sizeof(c10::DeviceIndex) == 1, "DeviceIndex is not 8-bit");
     // Note [Hazard when concatenating signed integers]
     // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
     // We must first convert to a same-sized unsigned type, before promoting to
@@ -224,7 +209,7 @@ struct hash<c10::Device> {
     // sake.
     uint32_t bits = static_cast<uint32_t>(static_cast<uint8_t>(d.type()))
             << 16 |
-        static_cast<uint32_t>(static_cast<uint16_t>(d.index()));
+        static_cast<uint32_t>(static_cast<uint8_t>(d.index()));
     return std::hash<uint32_t>{}(bits);
   }
 };
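The "Hazard when concatenating signed integers" note kept above is the reason both the old and the new line go through a same-sized unsigned cast first: a negative index converted straight to uint32_t is sign-extended and would clobber the type bits packed into the high half. A standalone sketch (plain C++, not the c10 hash itself):

    #include <cstdint>
    #include <iostream>

    int main() {
      int8_t index = -1;  // the "default device" index

      // Wrong: sign extension fills the upper 24 bits with 1s (0xFFFFFFFF), so
      // OR-ing a device type into the high bits no longer works.
      uint32_t bad = static_cast<uint32_t>(index);

      // Right (what the hash above does): go through the same-sized unsigned
      // type first, which keeps only the low 8 bits (0x000000FF).
      uint32_t good = static_cast<uint32_t>(static_cast<uint8_t>(index));

      std::cout << std::hex << bad << ' ' << good << '\n';  // ffffffff ff
    }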
@@ -3169,7 +3169,7 @@ class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl {
 #if UINTPTR_MAX == 0xFFFFFFFF
   // This is a 32-bit system
   static constexpr bool check_sizes() {
-    constexpr size_t tsize = 21 * sizeof(int64_t);
+    constexpr size_t tsize = 20 * sizeof(int64_t);

     // clang-format off
     are_equal<sizeof(storage_), 4, FieldNameEnum::storage_>();
@@ -3181,7 +3181,7 @@ class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl {
     are_equal<sizeof(storage_offset_), 8, FieldNameEnum::storage_offset_>();
     are_equal<sizeof(numel_), 8, FieldNameEnum::numel_>();
     are_equal<sizeof(data_type_), 2, FieldNameEnum::data_type_>();
-    are_equal<sizeof(device_opt_), 6, FieldNameEnum::device_opt_>();
+    are_equal<sizeof(device_opt_), 3, FieldNameEnum::device_opt_>();
     are_equal<sizeof(key_set_), 8, FieldNameEnum::key_set_>();
     is_le<sizeof(TensorImpl), tsize, FieldNameEnum::TOTAL_SIZE>();
     // clang-format on
@@ -3206,7 +3206,7 @@ class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl {
     are_equal<sizeof(storage_offset_), 8, FieldNameEnum::storage_offset_>();
     are_equal<sizeof(numel_), 8, FieldNameEnum::numel_>();
     are_equal<sizeof(data_type_), 2, FieldNameEnum::data_type_>();
-    are_equal<sizeof(device_opt_), 6, FieldNameEnum::device_opt_>();
+    are_equal<sizeof(device_opt_), 3, FieldNameEnum::device_opt_>();
     are_equal<sizeof(key_set_), 8, FieldNameEnum::key_set_>();
     is_le<sizeof(TensorImpl), tsize, FieldNameEnum::TOTAL_SIZE>();
     // clang-format on
|
@ -99,7 +99,7 @@ DeviceIndex device_count() noexcept {
|
||||
try {
|
||||
auto result = device_count_impl(/*fail_if_no_driver=*/false);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
result <= c10::Device::MAX_NUM_DEVICES,
|
||||
result <= std::numeric_limits<DeviceIndex>::max(),
|
||||
"Too many CUDA devices, DeviceIndex overflowed");
|
||||
return result;
|
||||
} catch (const c10::Error& ex) {
|
||||
@@ -118,7 +118,7 @@ DeviceIndex device_count_ensure_non_zero() {
   // Zero gpus doesn't produce a warning in `device_count` but we fail here
   TORCH_CHECK(count, "No CUDA GPUs are available");
   TORCH_INTERNAL_ASSERT(
-      count <= c10::Device::MAX_NUM_DEVICES,
+      count <= std::numeric_limits<DeviceIndex>::max(),
       "Too many CUDA devices, DeviceIndex overflowed");
   return static_cast<DeviceIndex>(count);
 }
@@ -219,7 +219,8 @@ cudaError_t GetDevice(DeviceIndex* device) {
   auto err = cudaGetDevice(&tmp_device);
   if (err == cudaSuccess) {
     TORCH_INTERNAL_ASSERT(
-        tmp_device >= 0 && tmp_device < c10::Device::MAX_NUM_DEVICES,
+        tmp_device >= 0 &&
+            tmp_device <= std::numeric_limits<DeviceIndex>::max(),
         "cudaGetDevice returns invalid device ",
         tmp_device);
     *device = static_cast<DeviceIndex>(tmp_device);
@@ -269,7 +270,8 @@ DeviceIndex MaybeExchangeDevice(DeviceIndex to_device) {
   int tmp_cur_device = -1;
   C10_CUDA_CHECK(cudaGetDevice(&tmp_cur_device));
   TORCH_INTERNAL_ASSERT(
-      tmp_cur_device >= 0 && tmp_cur_device < c10::Device::MAX_NUM_DEVICES,
+      tmp_cur_device >= 0 &&
+          tmp_cur_device <= std::numeric_limits<DeviceIndex>::max(),
       "cudaGetDevice returns invalid device ",
       tmp_cur_device);
   auto cur_device = static_cast<DeviceIndex>(tmp_cur_device);
@@ -295,7 +297,8 @@ cudaError_t GetDevice(DeviceIndex* device) {
   auto err = cudaGetDevice(&tmp_device);
   if (err == cudaSuccess) {
     TORCH_INTERNAL_ASSERT(
-        tmp_device >= 0 && tmp_device < c10::Device::MAX_NUM_DEVICES,
+        tmp_device >= 0 &&
+            tmp_device <= std::numeric_limits<DeviceIndex>::max(),
         "cudaGetDevice returns invalid device ",
         tmp_device);
     *device = static_cast<DeviceIndex>(tmp_device);
@@ -37,3 +37,15 @@
 #else
 #define C10_CUDA_API C10_CUDA_IMPORT
 #endif
+
+/**
+ * The maximum number of GPUs that we recognizes. Increasing this beyond the
+ * initial limit of 16 broke Caffe2 testing, hence the ifdef guards.
+ * This value cannot be more than 255 because our DeviceIndex is a uint8_t.
+o */
+#ifdef FBCODE_CAFFE2
+// fbcode depends on this value being 16
+#define C10_COMPILE_TIME_MAX_GPUS 16
+#else
+#define C10_COMPILE_TIME_MAX_GPUS 64
+#endif
@@ -38,18 +38,18 @@ static int max_stream_priorities;
 // the destruction.
 #if !defined(USE_ROCM)
 // CUDA-only: used to initializes the stream pools (once)
-static c10::once_flag device_flags[c10::Device::MAX_NUM_DEVICES];
+static c10::once_flag device_flags[C10_COMPILE_TIME_MAX_GPUS];
 #endif
 static std::atomic<uint32_t>
     priority_counters[c10::cuda::max_compile_time_stream_priorities]
-                     [c10::Device::MAX_NUM_DEVICES];
+                     [C10_COMPILE_TIME_MAX_GPUS];

 static cudaStream_t streams[c10::cuda::max_compile_time_stream_priorities]
-                           [c10::Device::MAX_NUM_DEVICES][kStreamsPerPool];
+                           [C10_COMPILE_TIME_MAX_GPUS][kStreamsPerPool];
 #ifdef USE_ROCM
 static c10::once_flag
     stream_flags[c10::cuda::max_compile_time_stream_priorities]
-                [c10::Device::MAX_NUM_DEVICES][kStreamsPerPool];
+                [C10_COMPILE_TIME_MAX_GPUS][kStreamsPerPool];
 #endif

 // Note [HIP Lazy Streams]
@@ -168,10 +168,10 @@ static void initGlobalStreamState() {
   // Check if the number of GPUs matches the expected compile-time max number
   // of GPUs.
   TORCH_CHECK(
-      num_gpus <= c10::Device::MAX_NUM_DEVICES,
+      num_gpus <= C10_COMPILE_TIME_MAX_GPUS,
       "Number of CUDA devices on the machine is larger than the compiled "
       "max number of gpus expected (",
-      c10::Device::MAX_NUM_DEVICES,
+      C10_COMPILE_TIME_MAX_GPUS,
       "). Increase that and recompile.");
   int leastPriority = -1, greatestPriority = -1;
   C10_CUDA_CHECK(
@@ -224,8 +224,8 @@ std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>> ncclOpDevInfer(

 REGISTER_CUDA_OPERATOR(NCCLAllreduce, NCCLAllreduceOp);
 OPERATOR_SCHEMA(NCCLAllreduce)
-    .NumInputs(1, c10::Device::MAX_NUM_DEVICES)
-    .NumOutputs(1, c10::Device::MAX_NUM_DEVICES)
+    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
+    .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
     .CostInferenceFunction(NCCLAllreduceOp::CostInference)
     .TensorInferenceFunction(NCCLAllreduceOp::ShapeInference)
     .IdenticalTypeAndShape()
@@ -236,8 +236,8 @@ SHOULD_NOT_DO_GRADIENT(NCCLAllreduce);

 REGISTER_CUDA_OPERATOR(NCCLBroadcast, NCCLBroadcastOp);
 OPERATOR_SCHEMA(NCCLBroadcast)
-    .NumInputs(1, c10::Device::MAX_NUM_DEVICES)
-    .NumOutputs(1, c10::Device::MAX_NUM_DEVICES)
+    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
+    .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
     .IdenticalTypeAndShape()
     .InputsCanCrossDevices()
     .EnforceOneToOneInplace()
@@ -247,7 +247,7 @@ SHOULD_NOT_DO_GRADIENT(NCCLBroadcast);

 REGISTER_CUDA_OPERATOR(NCCLReduce, NCCLReduceOp);
 OPERATOR_SCHEMA(NCCLReduce)
-    .NumInputs(1, c10::Device::MAX_NUM_DEVICES)
+    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
     .NumOutputs(1)
     .IdenticalTypeAndShapeOfInput(0)
     .InputsCanCrossDevices()
@@ -257,16 +257,16 @@ SHOULD_NOT_DO_GRADIENT(NCCLReduce);

 REGISTER_CUDA_OPERATOR(NCCLAllGather, NCCLAllGatherOp);
 OPERATOR_SCHEMA(NCCLAllGather)
-    .NumInputs(1, c10::Device::MAX_NUM_DEVICES)
-    .NumOutputs(1, c10::Device::MAX_NUM_DEVICES)
+    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
+    .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
     .InputsCanCrossDevices()
     .DeviceInferenceFunction(ncclOpDevInfer);
 SHOULD_NOT_DO_GRADIENT(NCCLAllGather);

 REGISTER_CUDA_OPERATOR(NCCLReduceScatter, NCCLReduceScatterOp);
 OPERATOR_SCHEMA(NCCLReduceScatter)
-    .NumInputs(1, c10::Device::MAX_NUM_DEVICES)
-    .NumOutputs(1, c10::Device::MAX_NUM_DEVICES)
+    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
+    .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
     .InputsCanCrossDevices()
     .DeviceInferenceFunction(ncclOpDevInfer);
 SHOULD_NOT_DO_GRADIENT(NCCLReduceScatter);
@@ -178,8 +178,8 @@ static std::unordered_map<void*, uint8_t> g_cuda_device_affiliation;
 // Data structures for optional memory tracking. Access to these structures
 // is guarded by the CUDAContext::mutex.
 static std::unordered_map<void*, long> g_size_map;
-static std::vector<long> g_total_by_gpu_map(c10::Device::MAX_NUM_DEVICES, 0);
-static std::vector<long> g_max_by_gpu_map(c10::Device::MAX_NUM_DEVICES, 0);
+static std::vector<long> g_total_by_gpu_map(C10_COMPILE_TIME_MAX_GPUS, 0);
+static std::vector<long> g_max_by_gpu_map(C10_COMPILE_TIME_MAX_GPUS, 0);

 static long g_total_mem = 0;
 static long g_last_rep = 0;
@@ -208,10 +208,10 @@ static void Caffe2InitializeCuda() {
   // of GPUs.
   CAFFE_ENFORCE_LE(
       NumCudaDevices(),
-      c10::Device::MAX_NUM_DEVICES,
+      C10_COMPILE_TIME_MAX_GPUS,
       "Number of CUDA devices on the machine is larger than the compiled "
       "max number of gpus expected (",
-      c10::Device::MAX_NUM_DEVICES,
+      C10_COMPILE_TIME_MAX_GPUS,
       "). Increase that and recompile.");

   for (DeviceIndex i = 0; i < NumCudaDevices(); ++i) {
@@ -58,7 +58,7 @@ class CAFFE2_CUDA_API ThreadLocalCUDAObjects {

  private:
   ThreadLocalCUDAObjects() {
-    for (DeviceIndex i = 0; i < c10::Device::MAX_NUM_DEVICES; ++i) {
+    for (DeviceIndex i = 0; i < C10_COMPILE_TIME_MAX_GPUS; ++i) {
       cuda_streams_[i] = vector<c10::cuda::CUDAStream>();
     }
   }
|
||||
// WARNING: mapping from logical stream ID to c10::cuda::CUDAStream
|
||||
// is NOT bijective; multiple logical stream IDs may map to the
|
||||
// same underlying stream ID.
|
||||
vector<c10::cuda::CUDAStream> cuda_streams_[c10::Device::MAX_NUM_DEVICES];
|
||||
vector<c10::cuda::CUDAStream> cuda_streams_[C10_COMPILE_TIME_MAX_GPUS];
|
||||
std::unordered_map<c10::cuda::CUDAStream, cublasHandle_t> cublas_handles_;
|
||||
#ifdef CAFFE2_USE_CUDNN
|
||||
std::unordered_map<c10::cuda::CUDAStream, cudnnHandle_t> cudnn_handles_;
|
||||
|
@@ -188,7 +188,7 @@ class CuDNNWrapper {

   using PerGPUCuDNNStates = std::array<
       std::array<SyncedCuDNNState, CAFFE2_COMPILE_TIME_MAX_CUDNN_STATES>,
-      c10::Device::MAX_NUM_DEVICES>;
+      C10_COMPILE_TIME_MAX_GPUS>;
   static PerGPUCuDNNStates& cudnn_states();

   C10_DISABLE_COPY_AND_ASSIGN(CuDNNWrapper);
@@ -155,7 +155,7 @@ class MIOPENWrapper

   using PerGPUMIOPENStates = std::array<
       std::array<SyncedMIOPENState, CAFFE2_COMPILE_TIME_MAX_MIOPEN_STATES>,
-      c10::Device::MAX_NUM_DEVICES>;
+      C10_COMPILE_TIME_MAX_GPUS>;
   static PerGPUMIOPENStates& miopen_states();

   C10_DISABLE_COPY_AND_ASSIGN(MIOPENWrapper);
@@ -3440,21 +3440,6 @@ def foo(x):
         else:
             cu.define(full)

-    def test_int16_device_index(self):
-        # This used to fail after the switch from int8 to int16 DeviceIndex as the ArgumentInfo struct hardcoded
-        # the bit width. Thus, the default device (-1) wrapped around to 255.
-        # See https://github.com/pytorch/pytorch/issues/115331
-        tensor = torch.tensor([1.])
-        code_template = """
-        def fn(x):
-            return x.device
-        """
-        cu = torch.jit.CompilationUnit()
-        cu.define(code_template)
-        res = cu.fn(tensor)
-        self.assertEqual(tensor.device, res)
-
-
     def test_namedtuple_python(self):
         global MyTuple, MyMod  # see [local resolution in python]
         MyTuple = namedtuple('MyTuple', ['a'])
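The comment in the removed test describes a bit-field width bug: ArgumentInfo hard-codes an unsigned 8-bit field for the device index, so the default index -1 reads back as 255 once DeviceIndex is wider than 8 bits. A standalone sketch of that wrap-around (a stand-in struct, not the real ArgumentInfo):

    #include <cstdint>
    #include <iostream>

    struct PackedArg {
      unsigned device_ : 8;  // hard-coded width, as in the restored ArgumentInfo
    };

    int main() {
      PackedArg arg{};
      int16_t default_device = -1;       // the "no index" sentinel
      arg.device_ = default_device;      // only the low 8 bits survive
      std::cout << arg.device_ << '\n';  // prints 255
      // With an 8-bit DeviceIndex, converting 255 back to int8_t yields -1 again,
      // which is why the hard-coded width was harmless before the switch.
      std::cout << static_cast<int>(static_cast<int8_t>(arg.device_)) << '\n';  // prints -1
    }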
@@ -1017,22 +1017,6 @@ class TestDeviceUtils(TestCase):
             tree_all_only(torch.Tensor, lambda x: x.device.type == 'meta', r)
         )

-    def test_int16_device_index(self):
-        # Test if index does not wrap around when larger than int8
-        large_index = 500
-        x = torch.device('meta', large_index)
-        self.assertEqual(x.index, large_index)
-
-    def test_raise_on_device_index_out_of_bounds(self):
-        # Tests if an error is raised when the device index is out of bounds
-        index_larger_than_max = 100000
-        error_msg_regex = "^Device index must be.*"
-        # Explicit index
-        with self.assertRaisesRegex(RuntimeError, error_msg_regex):
-            x = torch.device('meta', index=index_larger_than_max)
-        # Index in device string
-        with self.assertRaisesRegex(RuntimeError, error_msg_regex):
-            x = torch.device(f'meta:{index_larger_than_max}')

 instantiate_device_type_tests(TestDeviceUtils, globals())

@@ -31,7 +31,10 @@ PyObject* THPDevice_repr(THPDevice* self) {
   std::ostringstream oss;
   oss << "device(type=\'" << self->device.type() << "\'";
   if (self->device.has_index()) {
-    oss << ", index=" << self->device.index();
+    // `self->device.index()` returns uint8_t which is treated as ascii while
+    // printing, hence casting it to uint16_t.
+    // https://stackoverflow.com/questions/19562103/uint8-t-cant-be-printed-with-cout
+    oss << ", index=" << static_cast<uint16_t>(self->device.index());
   }
   oss << ")";
   return THPUtils_packString(oss.str().c_str());
@@ -74,11 +77,7 @@ PyObject* THPDevice_pynew(
     device_index = r.toInt64(1);
     // -1 is allowed in ATen/C++, to mean the default device, but not in
     // Python.
-    TORCH_CHECK(
-        device_index >= 0 && device_index < c10::Device::MAX_NUM_DEVICES,
-        "Device index must be between 0 and ",
-        c10::Device::MAX_NUM_DEVICES - 1,
-        " inclusively.");
+    TORCH_CHECK(device_index >= 0, "Device index must not be negative");
   }
   at::Device device(
       as_device.type(), static_cast<c10::DeviceIndex>(device_index));
@@ -2028,10 +2028,23 @@ Call this whenever a new thread is created in order to propagate values from
 // torch/csrc/pybind.h` would solve this but it caused segmentation fault in
 // my environment.
 using _DeviceDtypeKey = std::pair<at::Device, std::string>;
+// Custom hasher is necessary to make unordered_map compilable for Windows
+// debug targets. As `at::native::ParamsHash` only works on structs with
+// standard layout, but std::string isn't one in Visual C++ debug builds,
+// which one can easily verify by running something like:
+// #define _DEBUG
+// #include <type_traits>
+// #include <string>
+// static_assert(std::is_standard_layout_v<std::string>, "Oh noes");
+// If above condition is not met, VC++ raises a very cryptic compilation
+// error. See
+// https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for
+// more detail
 struct _DeviceDtypeHasher {
   std::size_t operator()(const _DeviceDtypeKey& k) const noexcept {
-    return std::hash<at::Device>{}(k.first) ^
-        std::hash<std::string>{}(k.second);
+    static at::native::ParamsHash<at::Device> device_hasher;
+    static std::hash<std::string> string_hasher;
+    return device_hasher(k.first) ^ string_hasher(k.second);
   }
 };
 using _FlatMap = std::unordered_map<
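The restored comment and hasher above exist because std::unordered_map has no built-in hash for a std::pair key, so the two halves are hashed separately and combined. A standalone sketch of that pattern with ordinary standard-library hashers (a stand-in key type, not the real at::Device/ParamsHash pair):

    #include <cstddef>
    #include <functional>
    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <utility>

    // Stand-in for the (Device, dtype-string) key used above.
    using Key = std::pair<int, std::string>;

    struct PairHasher {
      std::size_t operator()(const Key& k) const noexcept {
        static std::hash<int> first_hasher;
        static std::hash<std::string> second_hasher;
        return first_hasher(k.first) ^ second_hasher(k.second);
      }
    };

    int main() {
      std::unordered_map<Key, int, PairHasher> counts;
      counts[Key{0, "float32"}] = 3;
      std::cout << counts[Key{0, "float32"}] << '\n';  // prints 3
    }

XOR keeps the combiner cheap and is what the binding code above does; keys whose halves hash to the same value will collide, which is acceptable for a grouping map like this one.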
@@ -1,12 +1,11 @@
 #pragma once

-#include <c10/core/Device.h>
+#include <c10/cuda/CUDAMacros.h>
 #include <bitset>
 #include <cstddef>

 namespace torch {

-using device_set = std::bitset<c10::Device::MAX_NUM_DEVICES>;
+using device_set = std::bitset<C10_COMPILE_TIME_MAX_GPUS>;

 } // namespace torch
@@ -2,7 +2,6 @@

 #include <ATen/core/jit_type.h>
 #include <ATen/core/stack.h>
-#include <c10/core/Device.h>
 #include <c10/util/hash.h>
 #include <c10/util/irange.h>
 #include <torch/csrc/Export.h>
@@ -57,10 +56,12 @@ struct ArgumentInfo {
  private:
   unsigned defined_ : 1;
   unsigned requires_grad_ : 1;
   unsigned : 5;
   unsigned dim_ : 8;
-  signed device_ : sizeof(c10::DeviceIndex) * 8;
+  unsigned device_ : 8;
   unsigned type_ : 8;
+  unsigned dev_type_ : 16;
+  unsigned : 16;
 };

 static_assert(
@@ -68,7 +69,7 @@ static_assert(
     "ArgumentInfo is to be a POD struct");
 static_assert(
     sizeof(ArgumentInfo) == sizeof(ArgumentInfo::plain_data_type),
-    "ArgumentInfo is expected to be a 64-bit struct");
+    "ArgumentInfo is expected to be a 32-bit struct");

 struct ArgumentSpec {
   ArgumentSpec(size_t num_flat_tensor_inputs, size_t num_flat_optional_inputs)
@@ -222,8 +223,8 @@ struct CompleteArgumentInfoPOD {
   unsigned type : 8; // scalar type
   unsigned defined : 1;
   unsigned requires_grad : 1;
-  signed dev_type : sizeof(c10::DeviceType) * 8;
-  signed device : sizeof(c10::DeviceIndex) * 8;
+  signed device : 14;
+  unsigned dev_type : 16;
   unsigned
       total_dims : 16; // all TensorInfoPODs are in CompleteArgumentSpec's
                        // tensor_info() array. total_dims is the total number of
@@ -807,11 +807,7 @@ inline at::Device toDevice(PyObject* obj) {
   }
   if (THPUtils_checkLong(obj)) {
     const auto device_index = THPUtils_unpackLong(obj);
-    TORCH_CHECK(
-        device_index >= 0 && device_index < c10::Device::MAX_NUM_DEVICES,
-        "Device index must be between 0 and ",
-        c10::Device::MAX_NUM_DEVICES - 1,
-        " inclusively.");
+    TORCH_CHECK(device_index >= 0, "Device index must not be negative");
     if (c10::is_privateuse1_backend_registered()) {
       return at::Device(
           c10::DeviceType::PrivateUse1,
@@ -5258,12 +5258,12 @@ class TensorPipeAgentCudaRpcTest(RpcAgentTestFixture, RpcTestCommon):
         options = self.rpc_backend_options
         dst = worker_name((self.rank + 1) % self.world_size)
         with self.assertRaisesRegex(
-            RuntimeError, "Device index must .*"
+            RuntimeError, "Device index must not be negative"
         ):
             options.set_device_map(dst, {-1: 0})

         with self.assertRaisesRegex(
-            RuntimeError, "Device index must .*"
+            RuntimeError, "Device index must not be negative"
        ):
             options.set_device_map(dst, {0: -1})
