Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[1/N] Apply clang-tidy to c10 cuda files (#111137)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111137
Approved by: https://github.com/zou3519, https://github.com/Skylion007
@@ -58,7 +58,7 @@ void reportMemoryUsageToProfiler(
 }
 
 void reportOutOfMemoryToProfiler(
-    int64_t alloc_size,
+    size_t alloc_size,
     size_t total_allocated,
     size_t total_reserved,
     Device device) {
@@ -73,7 +73,7 @@ void reportOutOfMemoryToProfiler(
 MemoryReportingInfoBase::MemoryReportingInfoBase() = default;
 
 void MemoryReportingInfoBase::reportOutOfMemory(
-    int64_t /*alloc_size*/,
+    size_t /*alloc_size*/,
     size_t /*total_allocated*/,
     size_t /*total_reserved*/,
     Device /*device*/) {}
@@ -247,7 +247,7 @@ struct C10_API MemoryReportingInfoBase : public c10::DebugInfoBase {
       Device device) = 0;
 
   virtual void reportOutOfMemory(
-      int64_t alloc_size,
+      size_t alloc_size,
       size_t total_allocated,
       size_t total_reserved,
       Device device);
@@ -264,7 +264,7 @@ C10_API void reportMemoryUsageToProfiler(
     Device device);
 
 C10_API void reportOutOfMemoryToProfiler(
-    int64_t alloc_size,
+    size_t alloc_size,
     size_t total_allocated,
     size_t total_reserved,
     Device device);
@@ -272,10 +272,7 @@ void ProfiledCPUMemoryReporter::OutOfMemory(size_t nbytes) {
   }
   if (profile_memory) {
     reportOutOfMemoryToProfiler(
-        static_cast<int64_t>(nbytes),
-        allocated,
-        0,
-        c10::Device(c10::DeviceType::CPU));
+        nbytes, allocated, 0, c10::Device(c10::DeviceType::CPU));
   }
 }
@@ -132,19 +132,24 @@ using stream_set = ska::flat_hash_set<cuda::CUDAStream>;
 
 using StatTypes = std::array<bool, static_cast<size_t>(StatType::NUM_TYPES)>;
 
-void update_stat(Stat& stat, int64_t amount) {
+void add_amount_to_stat(Stat& stat, size_t amount) {
   stat.current += amount;
-
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-      stat.current >= 0,
-      "Negative tracked stat in CUDA allocator (likely logic error).");
-
   stat.peak = std::max(stat.current, stat.peak);
-  if (amount > 0) {
-    stat.allocated += amount;
-  }
-  if (amount < 0) {
-    stat.freed += -amount;
-  }
+  stat.allocated += amount;
+}
+
+void decrease_amount_from_stat(Stat& stat, size_t amount) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      amount <= stat.current,
+      "Negative tracked stat in CUDA allocator (likely logic error).");
+  stat.current -= amount;
+  stat.freed += amount;
+}
+
+void update_stat(Stat& stat, int64_t amount) {
+  if (amount >= 0) {
+    add_amount_to_stat(stat, amount);
+  } else {
+    decrease_amount_from_stat(stat, -amount);
+  }
 }
 
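The hunk above is the core of the stats rework: the signed update_stat helper is split into unsigned add/decrease paths. A minimal, self-contained sketch of the new bookkeeping, with Stat trimmed down from the real c10 type and assert standing in for TORCH_INTERNAL_ASSERT_DEBUG_ONLY:

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>

// Simplified stand-in for c10::cuda::CUDACachingAllocator::Stat.
struct Stat {
  uint64_t current = 0;
  uint64_t peak = 0;
  uint64_t allocated = 0;
  uint64_t freed = 0;
};

// Unsigned "add" path: bumps current/allocated and tracks the peak.
void add_amount_to_stat(Stat& stat, size_t amount) {
  stat.current += amount;
  stat.peak = std::max(stat.current, stat.peak);
  stat.allocated += amount;
}

// Unsigned "decrease" path: an underflow check replaces the old
// "negative tracked stat" assertion on a signed counter.
void decrease_amount_from_stat(Stat& stat, size_t amount) {
  assert(amount <= stat.current && "negative tracked stat");
  stat.current -= amount;
  stat.freed += amount;
}

// Signed entry point kept for callers that still pass deltas.
void update_stat(Stat& stat, int64_t amount) {
  if (amount >= 0) {
    add_amount_to_stat(stat, static_cast<size_t>(amount));
  } else {
    decrease_amount_from_stat(stat, static_cast<size_t>(-amount));
  }
}

int main() {
  Stat s;
  update_stat(s, 512);   // allocation of 512 bytes
  update_stat(s, -128);  // partial free
  std::cout << s.current << " current, " << s.peak << " peak\n";
}

Compiled as written, the sketch prints "384 current, 512 peak", which is how the allocator keeps its high-water marks while the counters stay unsigned.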
@@ -166,13 +171,13 @@ void for_each_selected_stat_type(const StatTypes& stat_types, Func f) {
   }
 }
 
-void update_stat_array(
+void decrease_stat_array(
     StatArray& stat_array,
-    int64_t amount,
+    size_t amount,
     const StatTypes& stat_types) {
   for_each_selected_stat_type(
       stat_types, [&stat_array, amount](size_t stat_type) {
-        update_stat(stat_array[stat_type], amount);
+        decrease_amount_from_stat(stat_array[stat_type], amount);
       });
 }
 
@@ -190,6 +195,7 @@ struct BlockPool {
         owner_PrivatePool(private_pool) {}
   std::set<Block*, Comparison> blocks;
   std::set<Block*, Comparison> unmapped;
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
   const bool is_small;
   PrivatePool* owner_PrivatePool;
 };
@@ -451,6 +457,7 @@ struct ExpandableSegment {
   }
 
   char* ptr() const {
+    // NOLINTNEXTLINE(performance-no-int-to-ptr)
     return (char*)ptr_;
   }
   size_t size() const {
@@ -503,8 +510,8 @@ struct ExpandableSegment {
       handles_.pop_back();
     }
   }
-  void forEachAllocatedRange(std::function<void(size_t, size_t)> fn) {
-    auto start = 0;
+  void forEachAllocatedRange(const std::function<void(size_t, size_t)>& fn) {
+    size_t start = 0;
     for (auto i : c10::irange(handles_.size())) {
       if (handles_.at(i) && (i == 0 || !handles_.at(i - 1))) {
         start = i;
@@ -815,15 +822,15 @@ static std::string reportProcessMemoryInfo(int device) {
 
   std::vector<nvmlProcessInfo_v1_t> procs(8);
   unsigned int size = procs.size();
-  nvmlReturn_t r;
+  nvmlReturn_t r{NVML_ERROR_UNKNOWN};
   while ((r = DriverAPI::get()->nvmlDeviceGetComputeRunningProcesses_(
              nvml_device, &size, procs.data())) ==
          NVML_ERROR_INSUFFICIENT_SIZE) {
     procs.resize(size);
   }
-  TORCH_INTERNAL_ASSERT(NVML_SUCCESS == r);
   unsigned int self_pid = getpid();
   std::stringstream ss;
+  TORCH_INTERNAL_ASSERT(NVML_SUCCESS == r);
   ss << "";
   for (auto i : c10::irange(size)) {
     auto& proc = procs[i];
@@ -884,7 +891,7 @@ class DeviceCachingAllocator {
   bool set_fraction = false;
 
   bool record_history = false;
-  std::atomic<CreateContextFn> context_recorder_;
+  std::atomic<CreateContextFn> context_recorder_{};
   size_t alloc_trace_next = 0;
   RecordContext record_context_ = RecordContext::NEVER;
   size_t alloc_trace_max_entries_ = 1;
@@ -1065,9 +1072,9 @@ class DeviceCachingAllocator {
 
     c10::reportOutOfMemoryToProfiler(
         size,
-        stats.allocated_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
+        stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)]
             .current,
-        stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
+        stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
             .current,
         c10::Device(c10::DeviceType::CUDA, static_cast<DeviceIndex>(device)));
@@ -1136,11 +1143,11 @@ class DeviceCachingAllocator {
 
     bool split_remainder = should_split(params.block, params.size());
     return alloc_found_block(
-        std::move(params), orig_size, std::move(context), split_remainder);
+        params, orig_size, std::move(context), split_remainder);
   }
 
   Block* alloc_found_block(
-      AllocParams params,
+      const AllocParams& params,
       size_t orig_size,
       std::shared_ptr<GatheredContext> context,
       bool split_remainder) {
@@ -1175,28 +1182,24 @@ class DeviceCachingAllocator {
 
       if (already_split && !block->expandable_segment_) {
         // An already-split inactive block is being shrunk by size bytes.
-        update_stat_array(
-            stats.inactive_split_bytes,
-            -static_cast<std::int64_t>(block->size),
-            params.stat_types);
+        decrease_stat_array(
+            stats.inactive_split_bytes, block->size, params.stat_types);
       } else if (!block->expandable_segment_) {
         // A new split inactive block is being created from a previously unsplit
         // block, size remaining->size bytes.
         for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
-          update_stat(
-              stats.inactive_split_bytes[stat_type],
-              static_cast<std::int64_t>(remaining->size));
-          update_stat(stats.inactive_split[stat_type], 1);
+          add_amount_to_stat(
+              stats.inactive_split_bytes[stat_type], remaining->size);
+          add_amount_to_stat(stats.inactive_split[stat_type], 1);
         });
       }
 
     } else if (already_split && !block->expandable_segment_) {
       // An already-split block is becoming active
       for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
-        update_stat(
-            stats.inactive_split_bytes[stat_type],
-            -static_cast<std::int64_t>(block->size));
-        update_stat(stats.inactive_split[stat_type], -1);
+        decrease_amount_from_stat(
+            stats.inactive_split_bytes[stat_type], block->size);
+        decrease_amount_from_stat(stats.inactive_split[stat_type], 1);
       });
     }
 
@@ -1216,24 +1219,19 @@ class DeviceCachingAllocator {
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted);
 
     for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
-      update_stat(stats.allocation[stat_type], 1);
-      update_stat(
-          stats.allocated_bytes[stat_type],
-          static_cast<std::int64_t>(block->size));
-      update_stat(stats.active[stat_type], 1);
-      update_stat(
-          stats.active_bytes[stat_type],
-          static_cast<std::int64_t>(block->size));
-      update_stat(
-          stats.requested_bytes[stat_type],
-          static_cast<std::int64_t>(block->requested_size));
+      add_amount_to_stat(stats.allocation[stat_type], 1);
+      add_amount_to_stat(stats.allocated_bytes[stat_type], block->size);
+      add_amount_to_stat(stats.active[stat_type], 1);
+      add_amount_to_stat(stats.active_bytes[stat_type], block->size);
+      add_amount_to_stat(
+          stats.requested_bytes[stat_type], block->requested_size);
     });
     if (block->size >= CUDAAllocatorConfig::max_split_size())
-      update_stat(stats.oversize_allocations, 1);
+      add_amount_to_stat(stats.oversize_allocations, 1);
 
     c10::reportMemoryUsageToProfiler(
         block->ptr,
-        block->size,
+        static_cast<int64_t>(block->size),
         stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
         stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
         c10::Device(c10::DeviceType::CUDA, device));
@@ -1255,10 +1253,8 @@ class DeviceCachingAllocator {
 
     StatTypes stat_types = get_stat_types_for_pool(*block->pool);
     for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
-      update_stat(stats.allocation[stat_type], -1);
-      update_stat(
-          stats.allocated_bytes[stat_type],
-          -static_cast<std::int64_t>(block->size));
+      decrease_amount_from_stat(stats.allocation[stat_type], 1);
+      decrease_amount_from_stat(stats.allocated_bytes[stat_type], block->size);
     });
     if (record_history) {
       record_trace(
@@ -1269,7 +1265,7 @@ class DeviceCachingAllocator {
           context ? context : block->context_when_allocated);
     }
     if (block->size >= CUDAAllocatorConfig::max_split_size())
-      update_stat(stats.oversize_allocations, -1);
+      decrease_amount_from_stat(stats.oversize_allocations, 1);
 
     if (!block->stream_uses.empty()) {
       if (C10_UNLIKELY(captures_underway)) {
@@ -1287,7 +1283,7 @@ class DeviceCachingAllocator {
 
     c10::reportMemoryUsageToProfiler(
         orig_block_ptr,
-        -orig_block_size,
+        -static_cast<int64_t>(orig_block_size),
         stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
         stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
         c10::Device(c10::DeviceType::CUDA, block->device));
@@ -1467,7 +1463,7 @@ class DeviceCachingAllocator {
   void setSegmentStateToCheckpoint(
       Block* block,
       SegmentState& segment,
-      std::shared_ptr<GatheredContext> context,
+      const std::shared_ptr<GatheredContext>& context,
       RestoreResult& rr) {
     Block* curr_block = block;
     Block* last_block = block;
@@ -1497,8 +1493,7 @@ class DeviceCachingAllocator {
 
     // curr_block will become next pointer if it is split, so reassign with
     // the returned value
-    curr_block = alloc_found_block(
-        std::move(params), block_state.size, context, split);
+    curr_block = alloc_found_block(params, block_state.size, context, split);
 
     TORCH_CHECK(curr_block->ptr == block_state.ptr);
     TORCH_CHECK(curr_block->size == block_state.size);
@@ -1712,12 +1707,12 @@ class DeviceCachingAllocator {
     result.reserve(alloc_trace->size());
     result.insert(
         result.end(),
-        alloc_trace->begin() + alloc_trace_next,
+        alloc_trace->begin() + static_cast<std::ptrdiff_t>(alloc_trace_next),
        alloc_trace->end());
     result.insert(
         result.end(),
         alloc_trace->begin(),
-        alloc_trace->begin() + alloc_trace_next);
+        alloc_trace->begin() + static_cast<std::ptrdiff_t>(alloc_trace_next));
     return result;
   }
 
@@ -1977,7 +1972,7 @@ class DeviceCachingAllocator {
     total_allocated_memory += mapped_range.size;
     StatTypes stat_types = get_stat_types_for_pool(*to_map->pool);
     for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
-      update_stat(stats.reserved_bytes[stat_type], mapped_range.size);
+      add_amount_to_stat(stats.reserved_bytes[stat_type], mapped_range.size);
     });
     if (record_history) {
       record_trace(
@@ -2052,11 +2047,10 @@ class DeviceCachingAllocator {
 
     const std::array<Block*, 2> merge_candidates = {block->prev, block->next};
     for (Block* merge_candidate : merge_candidates) {
-      const int64_t subsumed_size =
-          try_merge_blocks(block, merge_candidate, pool);
+      auto subsumed_size = try_merge_blocks(block, merge_candidate, pool);
       if (subsumed_size > 0) {
         net_change_inactive_split_blocks -= 1;
-        net_change_inactive_split_size -= subsumed_size;
+        net_change_inactive_split_size -= static_cast<int64_t>(subsumed_size);
       }
     }
 
@@ -2068,7 +2062,7 @@ class DeviceCachingAllocator {
 
     if (block->is_split()) {
       net_change_inactive_split_blocks += 1;
-      net_change_inactive_split_size += block->size;
+      net_change_inactive_split_size += static_cast<int64_t>(block->size);
     }
 
     StatTypes stat_types = get_stat_types_for_pool(pool);
@@ -2087,13 +2081,11 @@ class DeviceCachingAllocator {
             stats.inactive_split_bytes[stat_type],
             net_change_inactive_split_size);
       }
-      update_stat(stats.active[stat_type], -1);
-      update_stat(
-          stats.active_bytes[stat_type],
-          -static_cast<std::int64_t>(original_block_size));
-      update_stat(
-          stats.requested_bytes[stat_type],
-          -static_cast<std::int64_t>(requested_size));
+      decrease_amount_from_stat(stats.active[stat_type], 1);
+      decrease_amount_from_stat(
+          stats.active_bytes[stat_type], original_block_size);
+      decrease_amount_from_stat(
+          stats.requested_bytes[stat_type], requested_size);
     });
   }
 
@@ -2392,11 +2384,11 @@ class DeviceCachingAllocator {
     total_allocated_memory += size;
     p.block = new Block(p.device(), p.stream(), size, p.pool, (char*)ptr);
     for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
-      update_stat(stats.segment[stat_type], 1);
-      update_stat(stats.reserved_bytes[stat_type], size);
+      add_amount_to_stat(stats.segment[stat_type], 1);
+      add_amount_to_stat(stats.reserved_bytes[stat_type], size);
     });
     if (size >= CUDAAllocatorConfig::max_split_size())
-      update_stat(stats.oversize_segments, 1);
+      add_amount_to_stat(stats.oversize_segments, 1);
 
     // p.block came from new, not cudaMalloc. It should not be nullptr here.
     TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr);
@@ -2518,14 +2510,12 @@ class DeviceCachingAllocator {
 
     StatTypes stat_types = get_stat_types_for_pool(*pool);
     for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
-      update_stat(stats.segment[stat_type], -1);
-      update_stat(
-          stats.reserved_bytes[stat_type],
-          -static_cast<std::int64_t>(block->size));
+      decrease_amount_from_stat(stats.segment[stat_type], 1);
+      decrease_amount_from_stat(stats.reserved_bytes[stat_type], block->size);
     });
 
     if (block->size >= CUDAAllocatorConfig::max_split_size())
-      update_stat(stats.oversize_segments, -1);
+      decrease_amount_from_stat(stats.oversize_segments, 1);
     if (record_history) {
       record_trace(
           TraceEntry::SEGMENT_FREE,
@@ -2583,7 +2573,7 @@ class DeviceCachingAllocator {
     total_allocated_memory -= unmapped.size;
     StatTypes stat_types = get_stat_types_for_pool(*block->pool);
     for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
-      update_stat(stats.reserved_bytes[stat_type], -unmapped.size);
+      decrease_amount_from_stat(stats.reserved_bytes[stat_type], unmapped.size);
     });
     if (record_history) {
       record_trace(
@@ -3101,7 +3091,7 @@ class NativeCachingAllocator : public CUDAAllocator {
   }
 
   void enablePeerAccess(int dev, int dev_to_access) override {
-    c10::cuda::CUDAGuard device_guard(dev);
+    c10::cuda::CUDAGuard device_guard(static_cast<DeviceIndex>(dev));
     cudaError_t err = cudaDeviceEnablePeerAccess(dev_to_access, 0);
     if (err == cudaErrorPeerAccessAlreadyEnabled) {
       // ignore and clear the error if access was already enabled
@@ -3174,7 +3164,7 @@ class NativeCachingAllocator : public CUDAAllocator {
     C10_CUDA_CHECK(c10::cuda::GetDevice(&curr_device));
     auto sp =
         std::shared_ptr<void>(dev, [handle, curr_device, this](void* ptr) {
-          cuda::CUDAGuard device_guard(curr_device);
+          cuda::CUDAGuard device_guard(static_cast<DeviceIndex>(curr_device));
           std::lock_guard<std::mutex> deleter_lock(IpcMutex);
           C10_CUDA_CHECK(cudaIpcCloseMemHandle(ptr));
           ipcMemHandle_to_devptr.erase(handle);
@@ -3212,13 +3202,13 @@ std::string format_size(uint64_t size) {
   if (size <= 1024) {
     os << size << " bytes";
   } else if (size <= 1048576) {
-    os << (size / 1024.0);
+    os << (static_cast<double>(size) / 1024.0);
     os << " KiB";
   } else if (size <= 1073741824ULL) {
-    os << size / 1048576.0;
+    os << static_cast<double>(size) / 1048576.0f;
     os << " MiB";
   } else {
-    os << size / 1073741824.0;
+    os << static_cast<double>(size) / 1073741824.0f;
     os << " GiB";
   }
   return os.str();
@@ -26,8 +26,6 @@ C10_DECLARE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback);
 #define REGISTER_FREE_MEMORY_CALLBACK(name, ...) \
   C10_REGISTER_CLASS(FreeCudaMemoryCallbacksRegistry, name, __VA_ARGS__);
 
-namespace cuda {
-
 // TODO: Turn this into an honest to goodness class. I briefly attempted to do
 // this, but it was a bit irritating to figure out how to also correctly
 // apply pimpl pattern so I didn't have to leak any internal implementation
@@ -41,15 +39,15 @@ namespace cuda {
 // not counted as a word boundary, so you would otherwise have to list each
 // of these functions.
 
-namespace CUDACachingAllocator {
+namespace cuda::CUDACachingAllocator {
 
 extern const size_t kLargeBuffer;
 
 struct Stat {
-  int64_t current = 0;
-  int64_t peak = 0;
-  int64_t allocated = 0;
-  int64_t freed = 0;
+  uint64_t current = 0;
+  uint64_t peak = 0;
+  uint64_t allocated = 0;
+  uint64_t freed = 0;
 };
 
 enum struct StatType : uint64_t {
@@ -59,7 +57,7 @@ enum struct StatType : uint64_t {
   NUM_TYPES = 3 // remember to update this whenever a new stat type is added
 };
 
-typedef std::array<Stat, static_cast<size_t>(StatType::NUM_TYPES)> StatArray;
+using StatArray = std::array<Stat, static_cast<size_t>(StatType::NUM_TYPES)>;
 
 // Struct containing memory allocator summary statistics for a device.
 struct DeviceStats {
@@ -98,10 +96,10 @@ struct DeviceStats {
   Stat oversize_segments;
 
   // SIZE: maximum block size that is allowed to be split.
-  int64_t max_split_size = 0;
+  size_t max_split_size = 0;
 };
 
-typedef std::shared_ptr<GatheredContext> (*CreateContextFn)(void);
+using CreateContextFn = std::shared_ptr<GatheredContext> (*)();
 
 // Struct containing info of an allocation block (i.e. a fractional part of a
 // cudaMalloc)..
@@ -123,7 +121,7 @@ struct SegmentInfo {
   int64_t requested_size = 0; // unrounded, actually requested size
   int64_t allocated_size = 0;
   int64_t active_size = 0;
-  cudaStream_t stream = 0;
+  cudaStream_t stream = nullptr;
   bool is_large = false;
   bool is_expandable = false;
   MempoolId_t owner_private_pool_id = {0, 0};
@@ -167,7 +165,7 @@ struct TraceEntry {
   int64_t addr_; // for OOM, this is the amount of free bytes reported by cuda
   std::shared_ptr<GatheredContext> context_;
   cudaStream_t stream_;
-  int64_t size_;
+  size_t size_;
 };
 
 struct SnapshotInfo {
@@ -195,9 +193,9 @@ std::string format_size(uint64_t size);
 
 using OutOfMemoryObserver = std::function<void(
     int64_t device,
-    int64_t allocated,
-    int64_t device_total,
-    int64_t device_free)>;
+    size_t allocated,
+    size_t device_total,
+    size_t device_free)>;
 
 class CUDAAllocator : public Allocator {
  public:
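A hypothetical caller written against the new size_t-based observer signature (install_oom_logger and the log text are illustrative only, not part of this change; it assumes the c10 CUDA allocator header is available, and attachOutOfMemoryObserver is the wrapper declared further down in this same file):

#include <c10/cuda/CUDACachingAllocator.h>
#include <iostream>

void install_oom_logger() {
  // The lambda matches the updated OutOfMemoryObserver alias above:
  // the device index stays signed, the byte counts are now unsigned.
  c10::cuda::CUDACachingAllocator::attachOutOfMemoryObserver(
      [](int64_t device, size_t allocated, size_t device_total, size_t device_free) {
        std::cerr << "CUDA OOM on device " << device << ": " << allocated
                  << " bytes allocated by PyTorch, " << device_free << " of "
                  << device_total << " bytes free on the device\n";
      });
}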
@@ -223,9 +221,9 @@ class CUDAAllocator : public Allocator {
   virtual void releasePool(int device, MempoolId_t mempool_id) = 0;
   // returns true if the allocated blocks are equal to expected live allocations
   virtual bool checkPoolLiveAllocations(
-      int device,
-      MempoolId_t mempool_id,
-      const std::unordered_set<void*>& expected_live_allocations) {
+      int /*device*/,
+      MempoolId_t /*mempool_id*/,
+      const std::unordered_set<void*>& /*expected_live_allocations*/) {
     TORCH_CHECK(
         false,
         name(),
@@ -350,7 +348,7 @@ inline std::shared_ptr<AllocatorState> getCheckpointState(
 inline CheckpointDelta setCheckpointPoolState(
     int device,
     std::shared_ptr<AllocatorState> pps) {
-  return get()->setCheckpointPoolState(device, pps);
+  return get()->setCheckpointPoolState(device, std::move(pps));
 }
 
 // CUDAGraph interactions
@@ -387,7 +385,7 @@ inline bool checkPoolLiveAllocations(
 }
 
 inline void attachOutOfMemoryObserver(OutOfMemoryObserver observer) {
-  return get()->attachOutOfMemoryObserver(observer);
+  return get()->attachOutOfMemoryObserver(std::move(observer));
 }
 
 inline void releasePool(int device, MempoolId_t mempool_id) {
@@ -395,7 +393,7 @@ inline void releasePool(int device, MempoolId_t mempool_id) {
 }
 // Not part of CUDA_ALLOCATOR_BACKEND_INTERFACE
 inline std::shared_ptr<void> getIpcDevPtr(std::string handle) {
-  return get()->getIpcDevPtr(handle);
+  return get()->getIpcDevPtr(std::move(handle));
 }
 
 inline std::string name() {
@@ -418,6 +416,5 @@ inline void enablePeerAccess(int dev, int dev_to_access) {
   return get()->enablePeerAccess(dev, dev_to_access);
 }
 
-} // namespace CUDACachingAllocator
-} // namespace cuda
+} // namespace cuda::CUDACachingAllocator
 } // namespace c10
@@ -227,7 +227,7 @@ cudaError_t SetDevice(int device) {
 }
 
 cudaError_t MaybeSetDevice(int device) {
-  if (hasPrimaryContext(device)) {
+  if (hasPrimaryContext(static_cast<DeviceIndex>(device))) {
     return c10::cuda::SetDevice(device);
   }
   targetDeviceIndex = device;
@@ -257,7 +257,7 @@ int MaybeExchangeDevice(int to_device) {
   if (to_device == cur_device) {
     return cur_device;
   }
-  if (hasPrimaryContext(to_device)) {
+  if (hasPrimaryContext(static_cast<DeviceIndex>(to_device))) {
    C10_CUDA_CHECK(cudaSetDevice(to_device));
   } else {
     targetDeviceIndex = to_device;
@@ -31,17 +31,13 @@ namespace {
 // General helpers
 
 struct UsageStream {
-  cudaStream_t stream;
-  int device;
+  cudaStream_t stream{nullptr};
+  int device{-1};
   UsageStream() = default;
   UsageStream(cudaStream_t s, int d) : stream(s), device(d) {}
   UsageStream(const UsageStream& us) = default;
-  UsageStream(const UsageStream&& us) : stream(us.stream), device(us.device) {}
-  UsageStream& operator=(UsageStream other) {
-    stream = other.stream;
-    device = other.device;
-    return *this;
-  }
+  UsageStream(UsageStream&& us) noexcept = default;
+  UsageStream& operator=(const UsageStream& other) = default;
 };
 
 bool operator==(const UsageStream& lhs, const UsageStream& rhs) {
@@ -58,10 +54,11 @@ struct PtrUsage {
   // recorded_streams holds side usage streams added by record_stream calls.
   // In other words, it does NOT include the original creation stream.
   ska::flat_hash_set<UsageStream, UsageStreamHash> recorded_streams;
-  UsageStream creation_stream{};
+  UsageStream creation_stream;
   uint64_t size;
   bool captured;
-  PtrUsage(uint64_t s, bool c) : size(s), captured(c) {}
+  PtrUsage(UsageStream stream, uint64_t s, bool c)
+      : creation_stream(stream), size(s), captured(c) {}
 };
 
 int device_count = 0;
@@ -148,7 +145,7 @@ bool capture_underway = false;
 // Assumes the caller holds general_mutex
 inline void lazy_init_device(int device) {
   if (!devs_initialized_flags[device]) {
-    CUDAGuard g(device);
+    CUDAGuard g(static_cast<DeviceIndex>(device));
 
     // See "Retaining memory in the pool" here:
     // https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-1/
@@ -200,7 +197,7 @@ inline void free_impl(PtrInfo::iterator& it) {
   // If the usage stream is a null (default) stream,
   // cudaFreeAsync infers the device from the ambient context,
   // so we need to set the right ambient context.
-  CUDAGuard g(creation_stream.device);
+  CUDAGuard g(static_cast<DeviceIndex>(creation_stream.device));
 
   if (recorded_streams.empty()) {
     // ptr was only used on one stream, which must have been
@@ -240,7 +237,7 @@ inline void free_impl(PtrInfo::iterator& it) {
 
       // cudaEventRecord requires that the input event and stream are on the
       // same device.
-      CUDAGuard g_usage(recorded_stream.device);
+      CUDAGuard g_usage(static_cast<DeviceIndex>(recorded_stream.device));
 
       sync_raw(recorded_stream.stream, dummy_unifying_free_stream.stream);
     }
@@ -326,7 +323,7 @@ void mallocAsync(void** devPtr, int device, size_t size, cudaStream_t stream) {
   // If stream is a null (default) stream,
   // cudaMallocAsync infers the device from the ambient context,
   // so we need to set the right ambient context.
-  CUDAGuard g(device);
+  CUDAGuard g(static_cast<DeviceType>(device));
 
   std::lock_guard<std::mutex> lk(general_mutex);
 
@@ -387,14 +384,12 @@ void mallocAsync(void** devPtr, int device, size_t size, cudaStream_t stream) {
     C10_CUDA_CHECK(err);
   }
 
-  auto inserted = ptr_info.emplace(*devPtr, PtrUsage(size, capture_underway));
+  auto inserted = ptr_info.emplace(
+      *devPtr, PtrUsage(UsageStream{stream, device}, size, capture_underway));
   TORCH_INTERNAL_ASSERT(
       inserted.second,
       "address returned by cudaMallocAsync already exists "
       "in ptr_info");
 
-  inserted.first->second.creation_stream = {stream, device};
-
   pytorch_used_bytes[device] += size;
 }
 
@@ -414,9 +409,17 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
     C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
     void* r = nullptr;
     if (size != 0) {
-      mallocAsync(&r, device, size, cuda::getCurrentCUDAStream(device));
+      mallocAsync(
+          &r,
+          device,
+          size,
+          cuda::getCurrentCUDAStream(static_cast<DeviceIndex>(device)));
     }
-    return {r, r, &local_raw_delete, Device(DeviceType::CUDA, device)};
+    return {
+        r,
+        r,
+        &local_raw_delete,
+        Device(DeviceType::CUDA, static_cast<DeviceIndex>(device))};
   }
   DeleterFnPtr raw_deleter() const override {
     return &local_raw_delete;
@@ -459,7 +462,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
 
     std::lock_guard<std::mutex> lk(general_mutex);
     assertValidDevice(device);
-    CUDAGuard g(device);
+    CUDAGuard g(static_cast<DeviceIndex>(device));
     // Should setMemoryFraction be allowed to trigger a full device context and
     // pool-creating lazy_init_device, or should we simply assert this device is
     // already initialized, ie
@@ -486,7 +489,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
 
     for (int dev = 0; dev < device_count; dev++) {
       if (devs_initialized_flags[dev]) {
-        CUDAGuard g(dev);
+        CUDAGuard g(static_cast<DeviceIndex>(dev));
 
         cudaMemPool_t mempool = nullptr;
         cudaDeviceGetDefaultMemPool(&mempool, dev);
@@ -530,7 +533,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
     // also stabilizes to a point where they all come straight from the pool.
     std::lock_guard<std::mutex> lk(general_mutex);
     assertValidDevice(device);
-    CUDAGuard g(device);
+    CUDAGuard g(static_cast<DeviceIndex>(device));
     lazy_init_device(device);
 
     size_t free_upper_bound = 0;
@@ -666,7 +669,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
     std::lock_guard<std::mutex> lk(general_mutex);
 
     if (devs_initialized_flags[device]) {
-      CUDAGuard g(device);
+      CUDAGuard g(static_cast<DeviceIndex>(device));
 
       cudaMemPool_t mempool = nullptr;
       C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
@@ -724,7 +727,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
   void resetPeakStats(int device) override {
     assertValidDevice(device);
 
-    CUDAGuard g(device);
+    CUDAGuard g(static_cast<DeviceIndex>(device));
     cudaMemPool_t mempool = nullptr;
     C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
     // Using zero as the reset value is the method recommended by Cuda driver
@@ -774,13 +777,14 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
         "CudaMallocAsync::notifyCaptureAboutToEnd called, "
         "but CudaMallocAsync::capture_underway is false.");
 
-    auto capture_stream = cuda::getCurrentCUDAStream(device);
+    auto capture_stream =
+        cuda::getCurrentCUDAStream(static_cast<DeviceIndex>(device));
 
     // See Note [Avoid dangling free streams during CUDA graph capture]
     for (const auto& free_stream : capture_free_streams) {
       // cudaEventRecord requires that the input event and stream are on the
       // same device.
-      CUDAGuard g(free_stream.device);
+      CUDAGuard g(static_cast<DeviceIndex>(free_stream.device));
 
       // CUDACachingAllocator.cpp uses raw cuda events, as do we.
       cudaEvent_t event = nullptr;
@@ -820,7 +824,11 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
     int device = 0;
     C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
     void* r = nullptr;
-    mallocAsync(&r, device, nbytes, cuda::getCurrentCUDAStream(device));
+    mallocAsync(
+        &r,
+        device,
+        nbytes,
+        cuda::getCurrentCUDAStream(static_cast<DeviceIndex>(device)));
     return r;
   }
 
@@ -842,7 +850,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
     // an error. cudaMallocAsync pools are unaffected by
     // cudaDeviceEnablePeerAccess. We need pool-specific enablement. See
    // https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-2/
-    c10::cuda::CUDAGuard device_guard(dev);
+    c10::cuda::CUDAGuard device_guard(static_cast<DeviceIndex>(dev));
     cudaMemPool_t mempool = nullptr;
     C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, dev_to_access));
     cudaMemAccessDesc desc = {};
@@ -1,5 +1,5 @@
 #include <c10/cuda/CUDAMiscFunctions.h>
-#include <stdlib.h>
+#include <cstdlib>
 
 namespace c10 {
 namespace cuda {
@@ -1,5 +1,4 @@
 #include <c10/core/impl/GPUTrace.h>
 #include <c10/cuda/CUDAFunctions.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <c10/cuda/CUDAStream.h>
 #include <c10/util/CallOnce.h>
@@ -11,6 +10,7 @@
 #include <mutex>
 #include <vector>
 
+// NOLINTBEGIN(*c-arrays*)
 namespace c10 {
 namespace cuda {
 
@@ -256,6 +256,7 @@ cudaStream_t CUDAStream::stream() const {
         " official API like c10::cuda::getStreamFromPool() to get a new stream.");
     return nullptr;
   } else if (st.isExt()) {
+    // NOLINTNEXTLINE(performance-no-int-to-ptr)
     return reinterpret_cast<cudaStream_t>(stream_id);
   } else {
     auto streamType = st.getStreamType();
@@ -341,3 +342,4 @@ std::ostream& operator<<(std::ostream& stream, const CUDAStream& s) {
 
 } // namespace cuda
 } // namespace c10
+// NOLINTEND(*c-arrays*)
@@ -339,7 +339,7 @@ struct KinetoThreadLocalState : public ProfilerStateBase {
   }
 
   void reportOutOfMemory(
-      int64_t alloc_size,
+      size_t alloc_size,
       size_t total_allocated,
       size_t total_reserved,
       c10::Device device) override {
@@ -807,12 +807,12 @@ PyObject* THCPModule_attachOutOfMemoryObserver(
   Py_XINCREF(observer);
   auto obs = [observer](
                  int64_t device,
-                 int64_t alloc,
-                 int64_t device_allocated,
-                 int64_t device_free) {
+                 size_t alloc,
+                 size_t device_allocated,
+                 size_t device_free) {
     py::gil_scoped_acquire g;
     PyObject* result = PyObject_CallFunction(
-        observer, "LLLL", device, alloc, device_allocated, device_free);
+        observer, "LKKK", device, alloc, device_allocated, device_free);
     if (!result) {
       throw py::error_already_set();
     }
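The format string tracks the signature change: in CPython's argument-building codes, "L" converts a C long long while "K" converts a C unsigned long long, so the now-unsigned byte counts reach the Python observer without sign misinterpretation. A hedged sketch of the same pattern outside the PyTorch binding (call_observer is a hypothetical helper, not code from this diff):

#include <Python.h>
#include <cstddef>
#include <cstdint>

// Hypothetical helper mirroring the binding above: the device index stays
// signed ("L"), the three byte counts are passed as unsigned long long ("K").
PyObject* call_observer(
    PyObject* observer,
    int64_t device,
    size_t alloc,
    size_t device_allocated,
    size_t device_free) {
  return PyObject_CallFunction(
      observer,
      "LKKK",
      static_cast<long long>(device),
      static_cast<unsigned long long>(alloc),
      static_cast<unsigned long long>(device_allocated),
      static_cast<unsigned long long>(device_free));
}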
@@ -207,7 +207,7 @@ struct ExtraFields<EventType::Allocation> : RawAllocation {
 template <>
 struct ExtraFields<EventType::OutOfMemory> {
   torch::profiler::impl::approx_time_t start_time_;
-  int64_t alloc_size_;
+  size_t alloc_size_;
   size_t total_allocated_;
   size_t total_reserved_;
   c10::DeviceType device_type_;