Add memory reporting for XPU to Memory Profiler (#152842)
Adds support for XPU `profile_memory` in the PyTorch profiler. Currently, when `profile_memory=True` is passed to `torch.profiler.profile`, no XPU memory is reported. For example, the profiling table printed by the code below is missing any `XPU Mem` columns:

<details><summary>profiling.py</summary>
<p>

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.profiler import profile, ProfilerActivity

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.conv1 = nn.Conv1d(20, 20, 15, padding="same")
        self.flatten = nn.Flatten()
        self.net1 = nn.Linear(2048, 4096)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(4096, 5)

    def forward(self, x):
        res = self.conv1(x)
        res = self.flatten(res)
        res = self.net1(res)
        return self.net2(self.relu(res))

def demo_basic():
    model = ToyModel().to("xpu")
    loss_fn = nn.MSELoss().to("xpu")
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.XPU],
        profile_memory=True,
    ) as prof:
        for epoch in range(10):
            optimizer.zero_grad()
            outputs = model(torch.randn(20, 2048).to("xpu"))
            labels = torch.randn(20, 5).to("xpu")
            loss_fn(outputs, labels).backward()
            optimizer.step()

    print(prof.key_averages().table(max_name_column_width=100, sort_by="xpu_time_total", row_limit=100))

if __name__ == "__main__":
    demo_basic()
```

</p>
</details>

```
Name  Self CPU %  Self CPU  CPU total %  CPU total  CPU time avg  Self XPU  Self XPU %  XPU total  XPU time avg  CPU Mem  Self CPU Mem  # of Calls
---------------------------------------------------------------------------------------------------------------------------------------------------
gemm_kernel  0.00%  0.000us  0.00%  0.000us  0.000us  1.501ms  44.73%  1.501ms  25.024us  0 b  0 b  60
autograd::engine::evaluate_function: AddmmBackward0  0.12%  1.067ms  30.47%  260.929ms  13.046ms  0.000us  0.00%  1.009ms  50.448us  0 b  0 b  20
AddmmBackward0  0.09%  744.983us  15.99%  136.944ms  6.847ms  0.000us  0.00%  784.640us  39.232us  0 b  0 b  20
aten::mm  15.41%  131.956ms  15.79%  135.167ms  3.379ms  784.640us  23.37%  784.640us  19.616us  0 b  0 b  40
aten::linear  0.02%  156.361us  20.58%  176.187ms  8.809ms  0.000us  0.00%  741.760us  37.088us  0 b  0 b  20
aten::addmm  20.25%  173.371ms  20.52%  175.723ms  8.786ms  741.760us  22.10%  741.760us  37.088us  0 b  0 b  20
Optimizer.step#SGD.step  0.40%  3.429ms  5.55%  47.509ms  4.751ms  0.000us  0.00%  488.960us  48.896us  0 b  0 b  10
aten::_foreach_add_  4.81%  41.162ms  5.15%  44.080ms  4.408ms  488.960us  14.57%  488.960us  48.896us  0 b  0 b  10
at::native::xpu::MultiTensorApplyKernelFunctor<at::n...  0.00%  0.000us  0.00%  0.000us  0.000us  422.880us  12.60%  422.880us  42.288us  0 b  0 b  10
autograd::engine::evaluate_function: ConvolutionBack...  0.03%  280.041us  4.36%  37.328ms  3.733ms  0.000us  0.00%  356.320us  35.632us  0 b  0 b  10
---------------------------------------------------------------------------------------------------------------------------------------------------
Self CPU time total: 856.227ms
Self XPU time total: 3.357ms
```

This PR updates XPUCachingAllocator.cpp to report allocation and deallocation events to the profiler, so the XPU memory columns now appear in the table:

```
Name  Self CPU %  Self CPU  CPU total %  CPU total  CPU time avg  Self XPU  Self XPU %  XPU total  XPU time avg  CPU Mem  Self CPU Mem  XPU Mem  Self XPU Mem  # of Calls
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------
gemm_kernel  0.00%  0.000us  0.00%  0.000us  0.000us  1.436ms  43.64%  1.436ms  23.939us  0 b  0 b  0 b  0 b  60
autograd::engine::evaluate_function: AddmmBackward0  0.13%  1.186ms  29.92%  262.875ms  13.144ms  0.000us  0.00%  1.005ms  50.272us  0 b  0 b  320.94 Mb  -4.69 Mb  20
AddmmBackward0  0.09%  815.288us  16.48%  144.802ms  7.240ms  0.000us  0.00%  790.720us  39.536us  0 b  0 b  325.47 Mb  0 b  20
aten::mm  15.86%  139.342ms  16.26%  142.875ms  3.572ms  790.720us  24.03%  790.720us  19.768us  0 b  0 b  325.47 Mb  325.47 Mb  40
aten::linear  0.02%  182.856us  20.46%  179.775ms  8.989ms  0.000us  0.00%  669.440us  33.472us  0 b  0 b  3.13 Mb  0 b  20
aten::addmm  20.10%  176.607ms  20.40%  179.210ms  8.961ms  669.440us  20.34%  669.440us  33.472us  0 b  0 b  3.13 Mb  3.13 Mb  20
Optimizer.step#SGD.step  0.42%  3.692ms  5.61%  49.267ms  4.927ms  0.000us  0.00%  486.640us  48.664us  0 b  0 b  0 b  0 b  10
aten::_foreach_add_  4.83%  42.439ms  5.19%  45.574ms  4.557ms  486.640us  14.79%  486.640us  48.664us  0 b  0 b  0 b  -20.00 Kb  10
at::native::xpu::MultiTensorApplyKernelFunctor<at::n...  0.00%  0.000us  0.00%  0.000us  0.000us  420.960us  12.79%  420.960us  42.096us  0 b  0 b  0 b  0 b  10
autograd::engine::evaluate_function: ConvolutionBack...  0.04%  310.719us  4.47%  39.279ms  3.928ms  0.000us  0.00%  339.520us  33.952us  0 b  0 b  -2.89 Mb  -3.12 Mb  10
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Self CPU time total: 878.627ms
Self XPU time total: 3.291ms
```

These XPU memory numbers match the profiling results of the same workload on CUDA.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152842
Approved by: https://github.com/guangyey, https://github.com/sraikund16
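Beyond the printed table, the reported device memory can also be read programmatically from the averaged events. The snippet below is a minimal sketch, not part of this PR, assuming the `device_memory_usage` and `self_device_memory_usage` attributes are exposed on `key_averages()` entries in this build; it reuses the `prof` object from profiling.py above.

```python
# Minimal sketch (not part of this PR): read per-op device memory
# programmatically instead of from the printed table. Assumes the
# device_memory_usage / self_device_memory_usage attributes are available
# on key_averages() entries in this build.
def print_device_memory(prof):
    for evt in prof.key_averages():
        if evt.device_memory_usage != 0:
            # self_device_memory_usage counts memory (de)allocated by the op
            # itself; device_memory_usage also includes its children.
            print(evt.key, evt.self_device_memory_usage, evt.device_memory_usage)
```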
Committed by: PyTorch MergeBot
Parent: 8817e5ac80
Commit: fe49b11e09
ATen XPU test build list (CMake):

```diff
@@ -122,6 +122,7 @@ list(APPEND ATen_XPU_TEST_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/xpu_device_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/xpu_event_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/xpu_generator_test.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/xpu_reportMemoryUsage_test.cpp
 )
 
 # ---[ Send the lists to the parent scope.
```
aten/src/ATen/test/xpu_reportMemoryUsage_test.cpp (new file, 69 lines):
```cpp
#include <ATen/test/reportMemoryUsage.h>

#include <gtest/gtest.h>

#include <c10/xpu/XPUCachingAllocator.h>

TEST(DeviceCachingAllocator, check_reporter) {
  auto reporter = std::make_shared<TestMemoryReportingInfo>();
  c10::DebugInfoGuard guard(c10::DebugInfoKind::PROFILER_STATE, reporter);

  auto _200kb = 200 * 1024;
  auto _500mb = 500 * 1024 * 1024;

  auto allocator = c10::xpu::XPUCachingAllocator::get();

  auto alloc1 = allocator->allocate(_200kb);
  auto r = reporter->getLatestRecord();
  EXPECT_EQ(alloc1.get(), r.ptr);
  EXPECT_LE(_200kb, r.alloc_size);
  EXPECT_LE(_200kb, r.total_allocated);
  EXPECT_LE(_200kb, r.total_reserved);
  EXPECT_TRUE(r.device.is_xpu());

  auto alloc1_true_ptr = r.ptr;
  auto alloc1_true_alloc_size = r.alloc_size;

  // I bet pytorch will not waste that much memory
  EXPECT_LT(r.total_allocated, 2 * _200kb);
  // I bet pytorch will not reserve that much memory
  EXPECT_LT(r.total_reserved, _500mb);

  auto alloc2 = allocator->allocate(_500mb);
  r = reporter->getLatestRecord();
  EXPECT_EQ(alloc2.get(), r.ptr);
  EXPECT_LE(_500mb, r.alloc_size);
  EXPECT_LE(_200kb + _500mb, r.total_allocated);
  EXPECT_LE(_200kb + _500mb, r.total_reserved);
  EXPECT_TRUE(r.device.is_xpu());
  auto alloc2_true_ptr = r.ptr;
  auto alloc2_true_alloc_size = r.alloc_size;

  auto max_reserved = r.total_reserved;

  alloc1.clear();
  r = reporter->getLatestRecord();
  EXPECT_EQ(alloc1_true_ptr, r.ptr);
  EXPECT_EQ(-alloc1_true_alloc_size, r.alloc_size);
  EXPECT_EQ(alloc2_true_alloc_size, r.total_allocated);
  // alloc2 remain, it is a memory free operation, so it shouldn't reserve more
  // memory.
  EXPECT_TRUE(
      alloc2_true_alloc_size <= static_cast<int64_t>(r.total_reserved) &&
      r.total_reserved <= max_reserved);
  EXPECT_TRUE(r.device.is_xpu());

  alloc2.clear();
  r = reporter->getLatestRecord();
  EXPECT_EQ(alloc2_true_ptr, r.ptr);
  EXPECT_EQ(-alloc2_true_alloc_size, r.alloc_size);
  EXPECT_EQ(0, r.total_allocated);
  EXPECT_TRUE(r.total_reserved <= max_reserved);
  EXPECT_TRUE(r.device.is_xpu());
}

int main(int argc, char* argv[]) {
  ::testing::InitGoogleTest(&argc, argv);
  c10::xpu::XPUCachingAllocator::init(1);
  return RUN_ALL_TESTS();
}
```
XPUCachingAllocator.cpp:

```diff
@@ -386,6 +386,13 @@ class DeviceCachingAllocator {
       stats.requested_bytes[stat_type].increase(block->requested_size);
     });
 
+    c10::reportMemoryUsageToProfiler(
+        block->ptr,
+        static_cast<int64_t>(block->size),
+        stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
+        stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
+        c10::Device(c10::DeviceType::XPU, device));
+
     return block;
   }
 
@@ -431,6 +438,13 @@ class DeviceCachingAllocator {
     auto reserved_bytes =
         stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
             .current;
+
+    c10::reportOutOfMemoryToProfiler(
+        static_cast<int64_t>(size),
+        allocated_bytes,
+        reserved_bytes,
+        c10::Device(c10::DeviceType::XPU, device));
+
     TORCH_CHECK_WITH(
         OutOfMemoryError,
         false,
@@ -455,6 +469,9 @@ class DeviceCachingAllocator {
     std::scoped_lock<std::recursive_mutex> lock(mutex);
     block->allocated = false;
 
+    auto orig_block_ptr = block->ptr;
+    auto orig_block_size = block->size;
+
     StatTypes stat_types = get_stat_types_for_pool(*block->pool);
     for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
       stats.allocated_bytes[stat_type].decrease(block->size);
@@ -465,6 +482,13 @@ class DeviceCachingAllocator {
     } else {
       free_block(block);
     }
+
+    c10::reportMemoryUsageToProfiler(
+        orig_block_ptr,
+        -static_cast<int64_t>(orig_block_size),
+        stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
+        stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
+        c10::Device(c10::DeviceType::XPU, block->device));
   }
 
   void recordStream(Block* block, xpu::XPUStream stream) {
```
Profiler Python tests (class TestProfiler):

```diff
@@ -605,6 +605,9 @@ class TestProfiler(TestCase):
         def create_cuda_tensor():
             return torch.rand(10, 10).cuda()
 
+        def create_xpu_tensor():
+            return torch.rand(10, 10).xpu()
+
         def create_mkldnn_tensor():
             return torch.rand(10, 10, dtype=torch.float32).to_mkldnn()
 
@@ -675,6 +678,30 @@ class TestProfiler(TestCase):
                 ],
             )
 
+        if torch.xpu.is_available():
+            create_xpu_tensor()
+            stats = run_profiler(create_xpu_tensor)
+            check_metrics(
+                stats,
+                "device_memory_usage",
+                allocs=[
+                    "test_user_scope_alloc",
+                    "aten::to",
+                    "aten::empty_strided",
+                ],
+                deallocs=[
+                    "test_user_scope_dealloc",
+                ],
+            )
+            check_metrics(
+                stats,
+                "cpu_memory_usage",
+                allocs=[
+                    "aten::rand",
+                    "aten::empty",
+                ],
+            )
+
         if torch.backends.mkldnn.is_available():
             create_mkldnn_tensor()
             stats = run_profiler(create_mkldnn_tensor)
@@ -699,6 +726,9 @@ class TestProfiler(TestCase):
             if torch.cuda.is_available():
                 y = torch.rand(10, 10).cuda()
                 del y
+            elif torch.xpu.is_available():
+                y = torch.rand(10, 10).to("xpu")
+                del y
             gc.collect()
         stats = prof.key_averages(group_by_input_shape=True)
         check_metrics(
@@ -709,6 +739,8 @@ class TestProfiler(TestCase):
         )
         if torch.cuda.is_available():
             check_metrics(stats, "device_memory_usage", deallocs=["[memory]"])
+        elif torch.xpu.is_available():
+            check_metrics(stats, "device_memory_usage", deallocs=["[memory]"])
 
     @unittest.skipIf(
         IS_JETSON, "Jetson has a guard against OOM since host and gpu memory are shared"
```
Python `profile` class (memory-record filtering):

```diff
@@ -563,7 +563,12 @@ class profile:
         return (
             mem_record.nbytes()
             if mem_record.device_type()
-            in [DeviceType.CUDA, DeviceType.PrivateUse1, DeviceType.HIP]
+            in [
+                DeviceType.CUDA,
+                DeviceType.PrivateUse1,
+                DeviceType.HIP,
+                DeviceType.XPU,
+            ]
             else 0
         )
 
```
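For illustration, the filter above can be read as the standalone sketch below. The helper name `_device_nbytes` and the set constant are hypothetical; only the device-type membership mirrors the actual change. The effect is that memory records from XPU allocations now contribute their byte counts to the device memory columns instead of being counted as zero.

```python
# Hedged sketch of the device-type filter above as a standalone helper.
# The name _device_nbytes is hypothetical; the set membership mirrors the diff.
from torch.autograd import DeviceType

_DEVICE_MEMORY_TYPES = {
    DeviceType.CUDA,
    DeviceType.PrivateUse1,
    DeviceType.HIP,
    DeviceType.XPU,  # newly included, so XPU allocations are no longer dropped
}

def _device_nbytes(mem_record):
    # A memory record counts toward device memory only if it came from one of
    # the device types above; otherwise it contributes 0 bytes.
    return mem_record.nbytes() if mem_record.device_type() in _DEVICE_MEMORY_TYPES else 0
```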