[1/N] Apply clang-tidy to c10 cuda files (#111137)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111137
Approved by: https://github.com/zou3519, https://github.com/Skylion007
Authored by cyy on 2023-10-17 04:52:47 +00:00
Committed by PyTorch MergeBot
Parent 46000bede6, commit 43b023694e
12 changed files with 149 additions and 155 deletions

View File

@ -58,7 +58,7 @@ void reportMemoryUsageToProfiler(
}
void reportOutOfMemoryToProfiler(
int64_t alloc_size,
size_t alloc_size,
size_t total_allocated,
size_t total_reserved,
Device device) {
@ -73,7 +73,7 @@ void reportOutOfMemoryToProfiler(
MemoryReportingInfoBase::MemoryReportingInfoBase() = default;
void MemoryReportingInfoBase::reportOutOfMemory(
int64_t /*alloc_size*/,
size_t /*alloc_size*/,
size_t /*total_allocated*/,
size_t /*total_reserved*/,
Device /*device*/) {}

View File

@ -247,7 +247,7 @@ struct C10_API MemoryReportingInfoBase : public c10::DebugInfoBase {
Device device) = 0;
virtual void reportOutOfMemory(
int64_t alloc_size,
size_t alloc_size,
size_t total_allocated,
size_t total_reserved,
Device device);
@ -264,7 +264,7 @@ C10_API void reportMemoryUsageToProfiler(
Device device);
C10_API void reportOutOfMemoryToProfiler(
int64_t alloc_size,
size_t alloc_size,
size_t total_allocated,
size_t total_reserved,
Device device);
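For reference, the two files above switch the out-of-memory hooks from a signed int64_t alloc_size to size_t (an OOM request size can never be negative), while the allocation/free path keeps its signed delta. A minimal sketch of a reporter built against the updated interface; the class name is hypothetical and the override set assumes the c10/core/Allocator.h declarations shown in these hunks:

#include <c10/core/Allocator.h>
#include <iostream>

struct LoggingMemoryReporter : public c10::MemoryReportingInfoBase {
  void reportMemoryUsage(
      void* /*ptr*/,
      int64_t alloc_size, // usage deltas stay signed: + on alloc, - on free
      size_t /*total_allocated*/,
      size_t /*total_reserved*/,
      c10::Device device) override {
    std::cout << "usage delta " << alloc_size << " bytes on " << device << '\n';
  }
  void reportOutOfMemory(
      size_t alloc_size, // request sizes are never negative, hence size_t
      size_t total_allocated,
      size_t /*total_reserved*/,
      c10::Device device) override {
    std::cout << "failed to allocate " << alloc_size << " bytes on " << device
              << " (" << total_allocated << " already allocated)\n";
  }
  bool memoryProfilingEnabled() const override {
    return true;
  }
};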

View File

@ -272,10 +272,7 @@ void ProfiledCPUMemoryReporter::OutOfMemory(size_t nbytes) {
}
if (profile_memory) {
reportOutOfMemoryToProfiler(
static_cast<int64_t>(nbytes),
allocated,
0,
c10::Device(c10::DeviceType::CPU));
nbytes, allocated, 0, c10::Device(c10::DeviceType::CPU));
}
}

View File

@ -132,19 +132,24 @@ using stream_set = ska::flat_hash_set<cuda::CUDAStream>;
using StatTypes = std::array<bool, static_cast<size_t>(StatType::NUM_TYPES)>;
void update_stat(Stat& stat, int64_t amount) {
void add_amount_to_stat(Stat& stat, size_t amount) {
stat.current += amount;
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
stat.current >= 0,
"Negative tracked stat in CUDA allocator (likely logic error).");
stat.peak = std::max(stat.current, stat.peak);
if (amount > 0) {
stat.allocated += amount;
}
if (amount < 0) {
stat.freed += -amount;
stat.allocated += amount;
}
void decrease_amount_from_stat(Stat& stat, size_t amount) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
amount <= stat.current,
"Negative tracked stat in CUDA allocator (likely logic error).");
stat.current -= amount;
stat.freed += amount;
}
void update_stat(Stat& stat, int64_t amount) {
if (amount >= 0) {
add_amount_to_stat(stat, amount);
} else {
decrease_amount_from_stat(stat, -amount);
}
}
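The hunk above splits the old signed update_stat into add_amount_to_stat / decrease_amount_from_stat because the Stat counters become unsigned later in this PR; with size_t, a decrease written as `current += -amount` would wrap around instead of going negative. A standalone sketch with a simplified Stat (not the allocator's full struct) showing the intended usage:

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>

// Simplified stand-in for the allocator's Stat: all counters unsigned.
struct Stat {
  uint64_t current = 0, peak = 0, allocated = 0, freed = 0;
};

void add_amount_to_stat(Stat& s, size_t amount) {
  s.current += amount;
  s.peak = std::max(s.current, s.peak);
  s.allocated += amount;
}

void decrease_amount_from_stat(Stat& s, size_t amount) {
  // Underflow would silently wrap an unsigned counter, so check first.
  assert(amount <= s.current && "negative tracked stat (likely logic error)");
  s.current -= amount;
  s.freed += amount;
}

int main() {
  Stat bytes;
  add_amount_to_stat(bytes, 1 << 20);        // allocate 1 MiB
  decrease_amount_from_stat(bytes, 1 << 20); // free it again
  std::cout << bytes.peak << " peak, " << bytes.freed << " freed\n";
}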
@ -166,13 +171,13 @@ void for_each_selected_stat_type(const StatTypes& stat_types, Func f) {
}
}
void update_stat_array(
void decrease_stat_array(
StatArray& stat_array,
int64_t amount,
size_t amount,
const StatTypes& stat_types) {
for_each_selected_stat_type(
stat_types, [&stat_array, amount](size_t stat_type) {
update_stat(stat_array[stat_type], amount);
decrease_amount_from_stat(stat_array[stat_type], amount);
});
}
@ -190,6 +195,7 @@ struct BlockPool {
owner_PrivatePool(private_pool) {}
std::set<Block*, Comparison> blocks;
std::set<Block*, Comparison> unmapped;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const bool is_small;
PrivatePool* owner_PrivatePool;
};
@ -451,6 +457,7 @@ struct ExpandableSegment {
}
char* ptr() const {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return (char*)ptr_;
}
size_t size() const {
@ -503,8 +510,8 @@ struct ExpandableSegment {
handles_.pop_back();
}
}
void forEachAllocatedRange(std::function<void(size_t, size_t)> fn) {
auto start = 0;
void forEachAllocatedRange(const std::function<void(size_t, size_t)>& fn) {
size_t start = 0;
for (auto i : c10::irange(handles_.size())) {
if (handles_.at(i) && (i == 0 || !handles_.at(i - 1))) {
start = i;
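forEachAllocatedRange now takes its callback by const reference, which is clang-tidy's performance-unnecessary-value-param advice: passing std::function by value copies whatever state the callable captured. A small illustration with a hypothetical traversal (not the allocator's own code):

#include <cstddef>
#include <functional>
#include <iostream>

// Taking the callback by const reference avoids copying its captured state.
void for_each_range(const std::function<void(size_t, size_t)>& fn) {
  fn(0, 4);
  fn(8, 16);
}

int main() {
  size_t total = 0;
  for_each_range([&total](size_t begin, size_t end) { total += end - begin; });
  std::cout << "covered " << total << " units\n"; // covered 12 units
}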
@ -815,15 +822,15 @@ static std::string reportProcessMemoryInfo(int device) {
std::vector<nvmlProcessInfo_v1_t> procs(8);
unsigned int size = procs.size();
nvmlReturn_t r;
nvmlReturn_t r{NVML_ERROR_UNKNOWN};
while ((r = DriverAPI::get()->nvmlDeviceGetComputeRunningProcesses_(
nvml_device, &size, procs.data())) ==
NVML_ERROR_INSUFFICIENT_SIZE) {
procs.resize(size);
}
TORCH_INTERNAL_ASSERT(NVML_SUCCESS == r);
unsigned int self_pid = getpid();
std::stringstream ss;
TORCH_INTERNAL_ASSERT(NVML_SUCCESS == r);
ss << "";
for (auto i : c10::irange(size)) {
auto& proc = procs[i];
@ -884,7 +891,7 @@ class DeviceCachingAllocator {
bool set_fraction = false;
bool record_history = false;
std::atomic<CreateContextFn> context_recorder_;
std::atomic<CreateContextFn> context_recorder_{};
size_t alloc_trace_next = 0;
RecordContext record_context_ = RecordContext::NEVER;
size_t alloc_trace_max_entries_ = 1;
@ -1065,9 +1072,9 @@ class DeviceCachingAllocator {
c10::reportOutOfMemoryToProfiler(
size,
stats.allocated_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)]
.current,
stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
.current,
c10::Device(c10::DeviceType::CUDA, static_cast<DeviceIndex>(device)));
@ -1136,11 +1143,11 @@ class DeviceCachingAllocator {
bool split_remainder = should_split(params.block, params.size());
return alloc_found_block(
std::move(params), orig_size, std::move(context), split_remainder);
params, orig_size, std::move(context), split_remainder);
}
Block* alloc_found_block(
AllocParams params,
const AllocParams& params,
size_t orig_size,
std::shared_ptr<GatheredContext> context,
bool split_remainder) {
@ -1175,28 +1182,24 @@ class DeviceCachingAllocator {
if (already_split && !block->expandable_segment_) {
// An already-split inactive block is being shrunk by size bytes.
update_stat_array(
stats.inactive_split_bytes,
-static_cast<std::int64_t>(block->size),
params.stat_types);
decrease_stat_array(
stats.inactive_split_bytes, block->size, params.stat_types);
} else if (!block->expandable_segment_) {
// A new split inactive block is being created from a previously unsplit
// block, size remaining->size bytes.
for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
update_stat(
stats.inactive_split_bytes[stat_type],
static_cast<std::int64_t>(remaining->size));
update_stat(stats.inactive_split[stat_type], 1);
add_amount_to_stat(
stats.inactive_split_bytes[stat_type], remaining->size);
add_amount_to_stat(stats.inactive_split[stat_type], 1);
});
}
} else if (already_split && !block->expandable_segment_) {
// An already-split block is becoming active
for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
update_stat(
stats.inactive_split_bytes[stat_type],
-static_cast<std::int64_t>(block->size));
update_stat(stats.inactive_split[stat_type], -1);
decrease_amount_from_stat(
stats.inactive_split_bytes[stat_type], block->size);
decrease_amount_from_stat(stats.inactive_split[stat_type], 1);
});
}
@ -1216,24 +1219,19 @@ class DeviceCachingAllocator {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted);
for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
update_stat(stats.allocation[stat_type], 1);
update_stat(
stats.allocated_bytes[stat_type],
static_cast<std::int64_t>(block->size));
update_stat(stats.active[stat_type], 1);
update_stat(
stats.active_bytes[stat_type],
static_cast<std::int64_t>(block->size));
update_stat(
stats.requested_bytes[stat_type],
static_cast<std::int64_t>(block->requested_size));
add_amount_to_stat(stats.allocation[stat_type], 1);
add_amount_to_stat(stats.allocated_bytes[stat_type], block->size);
add_amount_to_stat(stats.active[stat_type], 1);
add_amount_to_stat(stats.active_bytes[stat_type], block->size);
add_amount_to_stat(
stats.requested_bytes[stat_type], block->requested_size);
});
if (block->size >= CUDAAllocatorConfig::max_split_size())
update_stat(stats.oversize_allocations, 1);
add_amount_to_stat(stats.oversize_allocations, 1);
c10::reportMemoryUsageToProfiler(
block->ptr,
block->size,
static_cast<int64_t>(block->size),
stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
c10::Device(c10::DeviceType::CUDA, device));
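Note the asymmetry above: stat bookkeeping now uses the unsigned helpers, but reportMemoryUsageToProfiler keeps a signed delta convention, so block->size is cast to int64_t here and negated with -static_cast<int64_t>(orig_block_size) in the free path below. A hedged sketch of that convention with a hypothetical tracker:

#include <cstddef>
#include <cstdint>
#include <iostream>

// Hypothetical stand-in for the profiler hook: deltas are signed,
// running totals are unsigned, mirroring the signatures in this diff.
void report_usage_delta(int64_t delta, size_t total_allocated) {
  std::cout << (delta >= 0 ? "alloc " : "free ") << delta << " bytes ("
            << total_allocated << " bytes live)\n";
}

int main() {
  size_t live = 0;
  size_t block = 4096;
  live += block;
  report_usage_delta(static_cast<int64_t>(block), live);  // allocation: +size
  live -= block;
  report_usage_delta(-static_cast<int64_t>(block), live); // free: -size
}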
@ -1255,10 +1253,8 @@ class DeviceCachingAllocator {
StatTypes stat_types = get_stat_types_for_pool(*block->pool);
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
update_stat(stats.allocation[stat_type], -1);
update_stat(
stats.allocated_bytes[stat_type],
-static_cast<std::int64_t>(block->size));
decrease_amount_from_stat(stats.allocation[stat_type], 1);
decrease_amount_from_stat(stats.allocated_bytes[stat_type], block->size);
});
if (record_history) {
record_trace(
@ -1269,7 +1265,7 @@ class DeviceCachingAllocator {
context ? context : block->context_when_allocated);
}
if (block->size >= CUDAAllocatorConfig::max_split_size())
update_stat(stats.oversize_allocations, -1);
decrease_amount_from_stat(stats.oversize_allocations, 1);
if (!block->stream_uses.empty()) {
if (C10_UNLIKELY(captures_underway)) {
@ -1287,7 +1283,7 @@ class DeviceCachingAllocator {
c10::reportMemoryUsageToProfiler(
orig_block_ptr,
-orig_block_size,
-static_cast<int64_t>(orig_block_size),
stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
c10::Device(c10::DeviceType::CUDA, block->device));
@ -1467,7 +1463,7 @@ class DeviceCachingAllocator {
void setSegmentStateToCheckpoint(
Block* block,
SegmentState& segment,
std::shared_ptr<GatheredContext> context,
const std::shared_ptr<GatheredContext>& context,
RestoreResult& rr) {
Block* curr_block = block;
Block* last_block = block;
@ -1497,8 +1493,7 @@ class DeviceCachingAllocator {
// curr_block will become next pointer if it is split, so reassign with
// the returned value
curr_block = alloc_found_block(
std::move(params), block_state.size, context, split);
curr_block = alloc_found_block(params, block_state.size, context, split);
TORCH_CHECK(curr_block->ptr == block_state.ptr);
TORCH_CHECK(curr_block->size == block_state.size);
@ -1712,12 +1707,12 @@ class DeviceCachingAllocator {
result.reserve(alloc_trace->size());
result.insert(
result.end(),
alloc_trace->begin() + alloc_trace_next,
alloc_trace->begin() + static_cast<std::ptrdiff_t>(alloc_trace_next),
alloc_trace->end());
result.insert(
result.end(),
alloc_trace->begin(),
alloc_trace->begin() + alloc_trace_next);
alloc_trace->begin() + static_cast<std::ptrdiff_t>(alloc_trace_next));
return result;
}
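The static_cast<std::ptrdiff_t> above exists because vector iterators advance by their signed difference_type; adding an unsigned size_t offset trips clang-tidy's sign-conversion checks. The alloc_trace buffer is a ring, and the same rotate-by-offset pattern looks like this in isolation (hypothetical data):

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<int> ring = {4, 5, 0, 1, 2, 3};
  size_t next = 2; // logical start of the ring buffer

  // Rotate into chronological order; the casts keep iterator arithmetic signed.
  std::vector<int> ordered;
  ordered.insert(
      ordered.end(),
      ring.begin() + static_cast<std::ptrdiff_t>(next),
      ring.end());
  ordered.insert(
      ordered.end(),
      ring.begin(),
      ring.begin() + static_cast<std::ptrdiff_t>(next));

  for (int v : ordered) {
    std::cout << v << ' '; // prints: 0 1 2 3 4 5
  }
  std::cout << '\n';
}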
@ -1977,7 +1972,7 @@ class DeviceCachingAllocator {
total_allocated_memory += mapped_range.size;
StatTypes stat_types = get_stat_types_for_pool(*to_map->pool);
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
update_stat(stats.reserved_bytes[stat_type], mapped_range.size);
add_amount_to_stat(stats.reserved_bytes[stat_type], mapped_range.size);
});
if (record_history) {
record_trace(
@ -2052,11 +2047,10 @@ class DeviceCachingAllocator {
const std::array<Block*, 2> merge_candidates = {block->prev, block->next};
for (Block* merge_candidate : merge_candidates) {
const int64_t subsumed_size =
try_merge_blocks(block, merge_candidate, pool);
auto subsumed_size = try_merge_blocks(block, merge_candidate, pool);
if (subsumed_size > 0) {
net_change_inactive_split_blocks -= 1;
net_change_inactive_split_size -= subsumed_size;
net_change_inactive_split_size -= static_cast<int64_t>(subsumed_size);
}
}
@ -2068,7 +2062,7 @@ class DeviceCachingAllocator {
if (block->is_split()) {
net_change_inactive_split_blocks += 1;
net_change_inactive_split_size += block->size;
net_change_inactive_split_size += static_cast<int64_t>(block->size);
}
StatTypes stat_types = get_stat_types_for_pool(pool);
@ -2087,13 +2081,11 @@ class DeviceCachingAllocator {
stats.inactive_split_bytes[stat_type],
net_change_inactive_split_size);
}
update_stat(stats.active[stat_type], -1);
update_stat(
stats.active_bytes[stat_type],
-static_cast<std::int64_t>(original_block_size));
update_stat(
stats.requested_bytes[stat_type],
-static_cast<std::int64_t>(requested_size));
decrease_amount_from_stat(stats.active[stat_type], 1);
decrease_amount_from_stat(
stats.active_bytes[stat_type], original_block_size);
decrease_amount_from_stat(
stats.requested_bytes[stat_type], requested_size);
});
}
@ -2392,11 +2384,11 @@ class DeviceCachingAllocator {
total_allocated_memory += size;
p.block = new Block(p.device(), p.stream(), size, p.pool, (char*)ptr);
for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
update_stat(stats.segment[stat_type], 1);
update_stat(stats.reserved_bytes[stat_type], size);
add_amount_to_stat(stats.segment[stat_type], 1);
add_amount_to_stat(stats.reserved_bytes[stat_type], size);
});
if (size >= CUDAAllocatorConfig::max_split_size())
update_stat(stats.oversize_segments, 1);
add_amount_to_stat(stats.oversize_segments, 1);
// p.block came from new, not cudaMalloc. It should not be nullptr here.
TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr);
@ -2518,14 +2510,12 @@ class DeviceCachingAllocator {
StatTypes stat_types = get_stat_types_for_pool(*pool);
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
update_stat(stats.segment[stat_type], -1);
update_stat(
stats.reserved_bytes[stat_type],
-static_cast<std::int64_t>(block->size));
decrease_amount_from_stat(stats.segment[stat_type], 1);
decrease_amount_from_stat(stats.reserved_bytes[stat_type], block->size);
});
if (block->size >= CUDAAllocatorConfig::max_split_size())
update_stat(stats.oversize_segments, -1);
decrease_amount_from_stat(stats.oversize_segments, 1);
if (record_history) {
record_trace(
TraceEntry::SEGMENT_FREE,
@ -2583,7 +2573,7 @@ class DeviceCachingAllocator {
total_allocated_memory -= unmapped.size;
StatTypes stat_types = get_stat_types_for_pool(*block->pool);
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
update_stat(stats.reserved_bytes[stat_type], -unmapped.size);
decrease_amount_from_stat(stats.reserved_bytes[stat_type], unmapped.size);
});
if (record_history) {
record_trace(
@ -3101,7 +3091,7 @@ class NativeCachingAllocator : public CUDAAllocator {
}
void enablePeerAccess(int dev, int dev_to_access) override {
c10::cuda::CUDAGuard device_guard(dev);
c10::cuda::CUDAGuard device_guard(static_cast<DeviceIndex>(dev));
cudaError_t err = cudaDeviceEnablePeerAccess(dev_to_access, 0);
if (err == cudaErrorPeerAccessAlreadyEnabled) {
// ignore and clear the error if access was already enabled
@ -3174,7 +3164,7 @@ class NativeCachingAllocator : public CUDAAllocator {
C10_CUDA_CHECK(c10::cuda::GetDevice(&curr_device));
auto sp =
std::shared_ptr<void>(dev, [handle, curr_device, this](void* ptr) {
cuda::CUDAGuard device_guard(curr_device);
cuda::CUDAGuard device_guard(static_cast<DeviceIndex>(curr_device));
std::lock_guard<std::mutex> deleter_lock(IpcMutex);
C10_CUDA_CHECK(cudaIpcCloseMemHandle(ptr));
ipcMemHandle_to_devptr.erase(handle);
@ -3212,13 +3202,13 @@ std::string format_size(uint64_t size) {
if (size <= 1024) {
os << size << " bytes";
} else if (size <= 1048576) {
os << (size / 1024.0);
os << (static_cast<double>(size) / 1024.0);
os << " KiB";
} else if (size <= 1073741824ULL) {
os << size / 1048576.0;
os << static_cast<double>(size) / 1048576.0f;
os << " MiB";
} else {
os << size / 1073741824.0;
os << static_cast<double>(size) / 1073741824.0f;
os << " GiB";
}
return os.str();
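format_size now spells out the uint64_t-to-double conversion instead of relying on the implicit one, which is what the narrowing/implicit-conversion checks ask for. A standalone sketch of the same formatting logic, assumed equivalent for illustration only (the surrounding stream setup is guessed):

#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>

std::string format_size(uint64_t size) {
  std::ostringstream os;
  os.precision(2);
  os << std::fixed;
  if (size <= 1024) {
    os << size << " bytes";
  } else if (size <= 1048576) {
    os << (static_cast<double>(size) / 1024.0) << " KiB";
  } else if (size <= 1073741824ULL) {
    os << (static_cast<double>(size) / 1048576.0) << " MiB";
  } else {
    os << (static_cast<double>(size) / 1073741824.0) << " GiB";
  }
  return os.str();
}

int main() {
  std::cout << format_size(3u << 20) << '\n'; // 3.00 MiB
}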

View File

@ -26,8 +26,6 @@ C10_DECLARE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback);
#define REGISTER_FREE_MEMORY_CALLBACK(name, ...) \
C10_REGISTER_CLASS(FreeCudaMemoryCallbacksRegistry, name, __VA_ARGS__);
namespace cuda {
// TODO: Turn this into an honest to goodness class. I briefly attempted to do
// this, but it was a bit irritating to figure out how to also correctly
// apply pimpl pattern so I didn't have to leak any internal implementation
@ -41,15 +39,15 @@ namespace cuda {
// not counted as a word boundary, so you would otherwise have to list each
// of these functions.
namespace CUDACachingAllocator {
namespace cuda::CUDACachingAllocator {
extern const size_t kLargeBuffer;
struct Stat {
int64_t current = 0;
int64_t peak = 0;
int64_t allocated = 0;
int64_t freed = 0;
uint64_t current = 0;
uint64_t peak = 0;
uint64_t allocated = 0;
uint64_t freed = 0;
};
enum struct StatType : uint64_t {
@ -59,7 +57,7 @@ enum struct StatType : uint64_t {
NUM_TYPES = 3 // remember to update this whenever a new stat type is added
};
typedef std::array<Stat, static_cast<size_t>(StatType::NUM_TYPES)> StatArray;
using StatArray = std::array<Stat, static_cast<size_t>(StatType::NUM_TYPES)>;
// Struct containing memory allocator summary statistics for a device.
struct DeviceStats {
@ -98,10 +96,10 @@ struct DeviceStats {
Stat oversize_segments;
// SIZE: maximum block size that is allowed to be split.
int64_t max_split_size = 0;
size_t max_split_size = 0;
};
typedef std::shared_ptr<GatheredContext> (*CreateContextFn)(void);
using CreateContextFn = std::shared_ptr<GatheredContext> (*)();
// Struct containing info of an allocation block (i.e. a fractional part of a
// cudaMalloc)..
@ -123,7 +121,7 @@ struct SegmentInfo {
int64_t requested_size = 0; // unrounded, actually requested size
int64_t allocated_size = 0;
int64_t active_size = 0;
cudaStream_t stream = 0;
cudaStream_t stream = nullptr;
bool is_large = false;
bool is_expandable = false;
MempoolId_t owner_private_pool_id = {0, 0};
@ -167,7 +165,7 @@ struct TraceEntry {
int64_t addr_; // for OOM, this is the amount of free bytes reported by cuda
std::shared_ptr<GatheredContext> context_;
cudaStream_t stream_;
int64_t size_;
size_t size_;
};
struct SnapshotInfo {
@ -195,9 +193,9 @@ std::string format_size(uint64_t size);
using OutOfMemoryObserver = std::function<void(
int64_t device,
int64_t allocated,
int64_t device_total,
int64_t device_free)>;
size_t allocated,
size_t device_total,
size_t device_free)>;
class CUDAAllocator : public Allocator {
public:
@ -223,9 +221,9 @@ class CUDAAllocator : public Allocator {
virtual void releasePool(int device, MempoolId_t mempool_id) = 0;
// returns true if the allocated blocks are equal to expected live allocations
virtual bool checkPoolLiveAllocations(
int device,
MempoolId_t mempool_id,
const std::unordered_set<void*>& expected_live_allocations) {
int /*device*/,
MempoolId_t /*mempool_id*/,
const std::unordered_set<void*>& /*expected_live_allocations*/) {
TORCH_CHECK(
false,
name(),
@ -350,7 +348,7 @@ inline std::shared_ptr<AllocatorState> getCheckpointState(
inline CheckpointDelta setCheckpointPoolState(
int device,
std::shared_ptr<AllocatorState> pps) {
return get()->setCheckpointPoolState(device, pps);
return get()->setCheckpointPoolState(device, std::move(pps));
}
// CUDAGraph interactions
@ -387,7 +385,7 @@ inline bool checkPoolLiveAllocations(
}
inline void attachOutOfMemoryObserver(OutOfMemoryObserver observer) {
return get()->attachOutOfMemoryObserver(observer);
return get()->attachOutOfMemoryObserver(std::move(observer));
}
inline void releasePool(int device, MempoolId_t mempool_id) {
@ -395,7 +393,7 @@ inline void releasePool(int device, MempoolId_t mempool_id) {
}
// Not part of CUDA_ALLOCATOR_BACKEND_INTERFACE
inline std::shared_ptr<void> getIpcDevPtr(std::string handle) {
return get()->getIpcDevPtr(handle);
return get()->getIpcDevPtr(std::move(handle));
}
inline std::string name() {
@ -418,6 +416,5 @@ inline void enablePeerAccess(int dev, int dev_to_access) {
return get()->enablePeerAccess(dev, dev_to_access);
}
} // namespace CUDACachingAllocator
} // namespace cuda
} // namespace cuda::CUDACachingAllocator
} // namespace c10
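The header hunks above bundle several mechanical modernizations: the C++17 nested namespace definition cuda::CUDACachingAllocator, using-aliases instead of typedef, nullptr for cudaStream_t, and unsigned counters (the Stat fields, max_split_size, TraceEntry::size_). A compressed, hypothetical illustration of the same idioms outside the real header:

#include <array>
#include <cstddef>
#include <cstdint>

// C++17 nested namespace definition replaces two nested blocks.
namespace demo::cuda::CachingAllocator {

struct Stat {
  // These counters can never be negative, so uint64_t replaces int64_t.
  uint64_t current = 0, peak = 0, allocated = 0, freed = 0;
};

enum struct StatType : uint64_t {
  AGGREGATE = 0,
  SMALL_POOL = 1,
  LARGE_POOL = 2,
  NUM_TYPES = 3
};

// modernize-use-using: aliases instead of typedef.
using StatArray = std::array<Stat, static_cast<size_t>(StatType::NUM_TYPES)>;
using CreateContextFn = void* (*)();

} // namespace demo::cuda::CachingAllocator

int main() {
  demo::cuda::CachingAllocator::StatArray stats{};
  return static_cast<int>(stats[0].current);
}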

View File

@ -227,7 +227,7 @@ cudaError_t SetDevice(int device) {
}
cudaError_t MaybeSetDevice(int device) {
if (hasPrimaryContext(device)) {
if (hasPrimaryContext(static_cast<DeviceIndex>(device))) {
return c10::cuda::SetDevice(device);
}
targetDeviceIndex = device;
@ -257,7 +257,7 @@ int MaybeExchangeDevice(int to_device) {
if (to_device == cur_device) {
return cur_device;
}
if (hasPrimaryContext(to_device)) {
if (hasPrimaryContext(static_cast<DeviceIndex>(to_device))) {
C10_CUDA_CHECK(cudaSetDevice(to_device));
} else {
targetDeviceIndex = to_device;
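Many hunks in this PR wrap int device numbers in static_cast<DeviceIndex> before constructing CUDAGuard or calling hasPrimaryContext: c10::DeviceIndex is a narrower integer (int8_t), so the narrowing conversion has to be spelled out to satisfy clang-tidy. A hedged sketch of the pattern with stand-in types, not the real guard:

#include <cstdint>
#include <iostream>

// Stand-ins for the c10 types: DeviceIndex is a narrow signed integer.
using DeviceIndex = int8_t;

struct DeviceGuardStub {
  explicit DeviceGuardStub(DeviceIndex index) : index_(index) {
    std::cout << "switching to device " << static_cast<int>(index_) << '\n';
  }
  DeviceIndex index_;
};

void enable_peer_access(int dev /* the public API still takes int */) {
  // The explicit cast documents (and silences) the int -> int8_t narrowing.
  DeviceGuardStub guard(static_cast<DeviceIndex>(dev));
}

int main() {
  enable_peer_access(1);
}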

View File

@ -31,17 +31,13 @@ namespace {
// General helpers
struct UsageStream {
cudaStream_t stream;
int device;
cudaStream_t stream{nullptr};
int device{-1};
UsageStream() = default;
UsageStream(cudaStream_t s, int d) : stream(s), device(d) {}
UsageStream(const UsageStream& us) = default;
UsageStream(const UsageStream&& us) : stream(us.stream), device(us.device) {}
UsageStream& operator=(UsageStream other) {
stream = other.stream;
device = other.device;
return *this;
}
UsageStream(UsageStream&& us) noexcept = default;
UsageStream& operator=(const UsageStream& other) = default;
};
bool operator==(const UsageStream& lhs, const UsageStream& rhs) {
@ -58,10 +54,11 @@ struct PtrUsage {
// recorded_streams holds side usage streams added by record_stream calls.
// In other words, it does NOT include the original creation stream.
ska::flat_hash_set<UsageStream, UsageStreamHash> recorded_streams;
UsageStream creation_stream{};
UsageStream creation_stream;
uint64_t size;
bool captured;
PtrUsage(uint64_t s, bool c) : size(s), captured(c) {}
PtrUsage(UsageStream stream, uint64_t s, bool c)
: creation_stream(stream), size(s), captured(c) {}
};
int device_count = 0;
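The UsageStream cleanup above drops a hand-written copy-assignment and a const&& "move" constructor in favor of defaulted special members plus in-class initializers, and PtrUsage now receives its creation_stream through the constructor instead of being patched up after insertion. A minimal sketch of the resulting shape (field types simplified, names hypothetical):

#include <cstdint>

struct UsageStreamSketch {
  // In-class initializers give a well-defined default state.
  void* stream{nullptr}; // stands in for cudaStream_t
  int device{-1};

  UsageStreamSketch() = default;
  UsageStreamSketch(void* s, int d) : stream(s), device(d) {}

  // For a type this simple, defaulting (or simply omitting) the special
  // members is both correct and what the modernize checks prefer.
  UsageStreamSketch(const UsageStreamSketch&) = default;
  UsageStreamSketch(UsageStreamSketch&&) noexcept = default;
  UsageStreamSketch& operator=(const UsageStreamSketch&) = default;
  UsageStreamSketch& operator=(UsageStreamSketch&&) noexcept = default;
};

struct PtrUsageSketch {
  UsageStreamSketch creation_stream;
  uint64_t size;
  bool captured;
  // Everything is set at construction; no post-insertion fix-up needed.
  PtrUsageSketch(UsageStreamSketch stream, uint64_t s, bool c)
      : creation_stream(stream), size(s), captured(c) {}
};

int main() {
  PtrUsageSketch usage(UsageStreamSketch{nullptr, 0}, 4096, false);
  return usage.captured ? 1 : 0;
}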
@ -148,7 +145,7 @@ bool capture_underway = false;
// Assumes the caller holds general_mutex
inline void lazy_init_device(int device) {
if (!devs_initialized_flags[device]) {
CUDAGuard g(device);
CUDAGuard g(static_cast<DeviceIndex>(device));
// See "Retaining memory in the pool" here:
// https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-1/
@ -200,7 +197,7 @@ inline void free_impl(PtrInfo::iterator& it) {
// If the usage stream is a null (default) stream,
// cudaFreeAsync infers the device from the ambient context,
// so we need to set the right ambient context.
CUDAGuard g(creation_stream.device);
CUDAGuard g(static_cast<DeviceIndex>(creation_stream.device));
if (recorded_streams.empty()) {
// ptr was only used on one stream, which must have been
@ -240,7 +237,7 @@ inline void free_impl(PtrInfo::iterator& it) {
// cudaEventRecord requires that the input event and stream are on the
// same device.
CUDAGuard g_usage(recorded_stream.device);
CUDAGuard g_usage(static_cast<DeviceIndex>(recorded_stream.device));
sync_raw(recorded_stream.stream, dummy_unifying_free_stream.stream);
}
@ -326,7 +323,7 @@ void mallocAsync(void** devPtr, int device, size_t size, cudaStream_t stream) {
// If stream is a null (default) stream,
// cudaMallocAsync infers the device from the ambient context,
// so we need to set the right ambient context.
CUDAGuard g(device);
CUDAGuard g(static_cast<DeviceType>(device));
std::lock_guard<std::mutex> lk(general_mutex);
@ -387,14 +384,12 @@ void mallocAsync(void** devPtr, int device, size_t size, cudaStream_t stream) {
C10_CUDA_CHECK(err);
}
auto inserted = ptr_info.emplace(*devPtr, PtrUsage(size, capture_underway));
auto inserted = ptr_info.emplace(
*devPtr, PtrUsage(UsageStream{stream, device}, size, capture_underway));
TORCH_INTERNAL_ASSERT(
inserted.second,
"address returned by cudaMallocAsync already exists "
"in ptr_info");
inserted.first->second.creation_stream = {stream, device};
pytorch_used_bytes[device] += size;
}
@ -414,9 +409,17 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
void* r = nullptr;
if (size != 0) {
mallocAsync(&r, device, size, cuda::getCurrentCUDAStream(device));
mallocAsync(
&r,
device,
size,
cuda::getCurrentCUDAStream(static_cast<DeviceIndex>(device)));
}
return {r, r, &local_raw_delete, Device(DeviceType::CUDA, device)};
return {
r,
r,
&local_raw_delete,
Device(DeviceType::CUDA, static_cast<DeviceIndex>(device))};
}
DeleterFnPtr raw_deleter() const override {
return &local_raw_delete;
@ -459,7 +462,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
std::lock_guard<std::mutex> lk(general_mutex);
assertValidDevice(device);
CUDAGuard g(device);
CUDAGuard g(static_cast<DeviceIndex>(device));
// Should setMemoryFraction be allowed to trigger a full device context and
// pool-creating lazy_init_device, or should we simply assert this device is
// already initialized, ie
@ -486,7 +489,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
for (int dev = 0; dev < device_count; dev++) {
if (devs_initialized_flags[dev]) {
CUDAGuard g(dev);
CUDAGuard g(static_cast<DeviceIndex>(dev));
cudaMemPool_t mempool = nullptr;
cudaDeviceGetDefaultMemPool(&mempool, dev);
@ -530,7 +533,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
// also stabilizes to a point where they all come straight from the pool.
std::lock_guard<std::mutex> lk(general_mutex);
assertValidDevice(device);
CUDAGuard g(device);
CUDAGuard g(static_cast<DeviceIndex>(device));
lazy_init_device(device);
size_t free_upper_bound = 0;
@ -666,7 +669,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
std::lock_guard<std::mutex> lk(general_mutex);
if (devs_initialized_flags[device]) {
CUDAGuard g(device);
CUDAGuard g(static_cast<DeviceIndex>(device));
cudaMemPool_t mempool = nullptr;
C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
@ -724,7 +727,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
void resetPeakStats(int device) override {
assertValidDevice(device);
CUDAGuard g(device);
CUDAGuard g(static_cast<DeviceIndex>(device));
cudaMemPool_t mempool = nullptr;
C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
// Using zero as the reset value is the method recommended by Cuda driver
@ -774,13 +777,14 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
"CudaMallocAsync::notifyCaptureAboutToEnd called, "
"but CudaMallocAsync::capture_underway is false.");
auto capture_stream = cuda::getCurrentCUDAStream(device);
auto capture_stream =
cuda::getCurrentCUDAStream(static_cast<DeviceIndex>(device));
// See Note [Avoid dangling free streams during CUDA graph capture]
for (const auto& free_stream : capture_free_streams) {
// cudaEventRecord requires that the input event and stream are on the
// same device.
CUDAGuard g(free_stream.device);
CUDAGuard g(static_cast<DeviceIndex>(free_stream.device));
// CUDACachingAllocator.cpp uses raw cuda events, as do we.
cudaEvent_t event = nullptr;
@ -820,7 +824,11 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
int device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
void* r = nullptr;
mallocAsync(&r, device, nbytes, cuda::getCurrentCUDAStream(device));
mallocAsync(
&r,
device,
nbytes,
cuda::getCurrentCUDAStream(static_cast<DeviceIndex>(device)));
return r;
}
@ -842,7 +850,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
// an error. cudaMallocAsync pools are unaffected by
// cudaDeviceEnablePeerAccess. We need pool-specific enablement. See
// https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-2/
c10::cuda::CUDAGuard device_guard(dev);
c10::cuda::CUDAGuard device_guard(static_cast<DeviceIndex>(dev));
cudaMemPool_t mempool = nullptr;
C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, dev_to_access));
cudaMemAccessDesc desc = {};

View File

@ -1,5 +1,5 @@
#include <c10/cuda/CUDAMiscFunctions.h>
#include <stdlib.h>
#include <cstdlib>
namespace c10 {
namespace cuda {

View File

@ -1,5 +1,4 @@
#include <c10/core/impl/GPUTrace.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/CallOnce.h>
@ -11,6 +10,7 @@
#include <mutex>
#include <vector>
// NOLINTBEGIN(*c-arrays*)
namespace c10 {
namespace cuda {
@ -256,6 +256,7 @@ cudaStream_t CUDAStream::stream() const {
" official API like c10::cuda::getStreamFromPool() to get a new stream.");
return nullptr;
} else if (st.isExt()) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return reinterpret_cast<cudaStream_t>(stream_id);
} else {
auto streamType = st.getStreamType();
@ -341,3 +342,4 @@ std::ostream& operator<<(std::ostream& stream, const CUDAStream& s) {
} // namespace cuda
} // namespace c10
// NOLINTEND(*c-arrays*)
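This file shows the two suppression styles used throughout the PR: NOLINTBEGIN/NOLINTEND brackets a whole region (here the *c-arrays* checks for the translation unit), while NOLINTNEXTLINE silences a single diagnostic such as performance-no-int-to-ptr for the stream_id cast. A tiny self-contained example of both forms (the code itself is invented):

#include <cstdint>

// NOLINTBEGIN(*c-arrays*)
// Region-wide suppression: the C array below is intentional.
static const int lookup[4] = {0, 1, 2, 3};
// NOLINTEND(*c-arrays*)

void* from_handle(std::uintptr_t handle) {
  // NOLINTNEXTLINE(performance-no-int-to-ptr)
  return reinterpret_cast<void*>(handle);
}

int main() {
  return lookup[0] + (from_handle(0) == nullptr ? 0 : 1); // returns 0
}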

View File

@ -339,7 +339,7 @@ struct KinetoThreadLocalState : public ProfilerStateBase {
}
void reportOutOfMemory(
int64_t alloc_size,
size_t alloc_size,
size_t total_allocated,
size_t total_reserved,
c10::Device device) override {

View File

@ -807,12 +807,12 @@ PyObject* THCPModule_attachOutOfMemoryObserver(
Py_XINCREF(observer);
auto obs = [observer](
int64_t device,
int64_t alloc,
int64_t device_allocated,
int64_t device_free) {
size_t alloc,
size_t device_allocated,
size_t device_free) {
py::gil_scoped_acquire g;
PyObject* result = PyObject_CallFunction(
observer, "LLLL", device, alloc, device_allocated, device_free);
observer, "LKKK", device, alloc, device_allocated, device_free);
if (!result) {
throw py::error_already_set();
}
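With the observer parameters now size_t, the PyObject_CallFunction format changes from "LLLL" to "LKKK": in the CPython C API's Py_BuildValue-style formats, "L" packs a signed long long while "K" packs an unsigned long long. A hedged, standalone sketch of the same call pattern with a hypothetical Python callback (not the THCPModule code itself):

#include <Python.h>
#include <cstddef>
#include <cstdint>

// Forward OOM details to a Python callable; casts make the varargs types explicit.
static void call_observer(
    PyObject* observer,
    int64_t device,
    size_t alloc,
    size_t device_allocated,
    size_t device_free) {
  PyObject* result = PyObject_CallFunction(
      observer,
      "LKKK",
      static_cast<long long>(device),
      static_cast<unsigned long long>(alloc),
      static_cast<unsigned long long>(device_allocated),
      static_cast<unsigned long long>(device_free));
  Py_XDECREF(result); // a real caller would also check for a pending exception
}

int main() {
  Py_Initialize();
  PyObject* globals = PyDict_New();
  PyDict_SetItemString(globals, "__builtins__", PyEval_GetBuiltins());
  PyObject* observer = PyRun_String(
      "lambda device, alloc, allocated, free: print(device, alloc, allocated, free)",
      Py_eval_input, globals, globals);
  if (observer != nullptr) {
    call_observer(observer, 0, 1 << 20, 1 << 22, 1 << 30);
  }
  Py_XDECREF(observer);
  Py_DECREF(globals);
  Py_Finalize();
}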

View File

@ -207,7 +207,7 @@ struct ExtraFields<EventType::Allocation> : RawAllocation {
template <>
struct ExtraFields<EventType::OutOfMemory> {
torch::profiler::impl::approx_time_t start_time_;
int64_t alloc_size_;
size_t alloc_size_;
size_t total_allocated_;
size_t total_reserved_;
c10::DeviceType device_type_;