Add memory reporting for XPU to Memory Profiler (#152842)

Adds support for XPU `profile_memory` in the PyTorch Profiler.

Currently, when `profile_memory=True` is passed to `torch.profiler.profile`, no XPU memory is reported. For example, the profiling table printed by the code below has no `XPU Mem` or `Self XPU Mem` columns:

<details><summary>profiling.py</summary>
<p>

```python
import torch
import torch.nn as nn
import torch.optim as optim

from torch.profiler import profile, ProfilerActivity

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.conv1 = nn.Conv1d(20, 20, 15, padding="same")
        self.flatten = nn.Flatten()
        self.net1 = nn.Linear(2048, 4096)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(4096, 5)

    def forward(self, x):
        res = self.conv1(x)
        res = self.flatten(res)
        res = self.net1(res)
        return self.net2(self.relu(res))

def demo_basic():
    model = ToyModel().to("xpu")
    loss_fn = nn.MSELoss().to("xpu")
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], profile_memory=True) as prof:
        for epoch in range(10):
            optimizer.zero_grad()
            outputs = model(torch.randn(20, 2048).to("xpu"))
            labels = torch.randn(20, 5).to("xpu")
            loss_fn(outputs, labels).backward()
            optimizer.step()
    print(prof.key_averages().table(max_name_column_width=100, sort_by="xpu_time_total", row_limit=100))

if __name__ == "__main__":
    demo_basic()
```
</p>
</details>

```
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg      Self XPU    Self XPU %     XPU total  XPU time avg       CPU Mem  Self CPU Mem    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                            gemm_kernel         0.00%       0.000us         0.00%       0.000us       0.000us       1.501ms        44.73%       1.501ms      25.024us           0 b           0 b            60
    autograd::engine::evaluate_function: AddmmBackward0         0.12%       1.067ms        30.47%     260.929ms      13.046ms       0.000us         0.00%       1.009ms      50.448us           0 b           0 b            20
                                         AddmmBackward0         0.09%     744.983us        15.99%     136.944ms       6.847ms       0.000us         0.00%     784.640us      39.232us           0 b           0 b            20
                                               aten::mm        15.41%     131.956ms        15.79%     135.167ms       3.379ms     784.640us        23.37%     784.640us      19.616us           0 b           0 b            40
                                           aten::linear         0.02%     156.361us        20.58%     176.187ms       8.809ms       0.000us         0.00%     741.760us      37.088us           0 b           0 b            20
                                            aten::addmm        20.25%     173.371ms        20.52%     175.723ms       8.786ms     741.760us        22.10%     741.760us      37.088us           0 b           0 b            20
                                Optimizer.step#SGD.step         0.40%       3.429ms         5.55%      47.509ms       4.751ms       0.000us         0.00%     488.960us      48.896us           0 b           0 b            10
                                    aten::_foreach_add_         4.81%      41.162ms         5.15%      44.080ms       4.408ms     488.960us        14.57%     488.960us      48.896us           0 b           0 b            10
at::native::xpu::MultiTensorApplyKernelFunctor<at::n...         0.00%       0.000us         0.00%       0.000us       0.000us     422.880us        12.60%     422.880us      42.288us           0 b           0 b            10
autograd::engine::evaluate_function: ConvolutionBack...         0.03%     280.041us         4.36%      37.328ms       3.733ms       0.000us         0.00%     356.320us      35.632us           0 b           0 b            10
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 856.227ms
Self XPU time total: 3.357ms
```

This PR updates XPUCachingAllocator.cpp to report allocation and deallocation events to the profiler, which populates the `XPU Mem` and `Self XPU Mem` columns in the table:
```
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg      Self XPU    Self XPU %     XPU total  XPU time avg       CPU Mem  Self CPU Mem       XPU Mem  Self XPU Mem    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                            gemm_kernel         0.00%       0.000us         0.00%       0.000us       0.000us       1.436ms        43.64%       1.436ms      23.939us           0 b           0 b           0 b           0 b            60
    autograd::engine::evaluate_function: AddmmBackward0         0.13%       1.186ms        29.92%     262.875ms      13.144ms       0.000us         0.00%       1.005ms      50.272us           0 b           0 b     320.94 Mb      -4.69 Mb            20
                                         AddmmBackward0         0.09%     815.288us        16.48%     144.802ms       7.240ms       0.000us         0.00%     790.720us      39.536us           0 b           0 b     325.47 Mb           0 b            20
                                               aten::mm        15.86%     139.342ms        16.26%     142.875ms       3.572ms     790.720us        24.03%     790.720us      19.768us           0 b           0 b     325.47 Mb     325.47 Mb            40
                                           aten::linear         0.02%     182.856us        20.46%     179.775ms       8.989ms       0.000us         0.00%     669.440us      33.472us           0 b           0 b       3.13 Mb           0 b            20
                                            aten::addmm        20.10%     176.607ms        20.40%     179.210ms       8.961ms     669.440us        20.34%     669.440us      33.472us           0 b           0 b       3.13 Mb       3.13 Mb            20
                                Optimizer.step#SGD.step         0.42%       3.692ms         5.61%      49.267ms       4.927ms       0.000us         0.00%     486.640us      48.664us           0 b           0 b           0 b           0 b            10
                                    aten::_foreach_add_         4.83%      42.439ms         5.19%      45.574ms       4.557ms     486.640us        14.79%     486.640us      48.664us           0 b           0 b           0 b     -20.00 Kb            10
at::native::xpu::MultiTensorApplyKernelFunctor<at::n...         0.00%       0.000us         0.00%       0.000us       0.000us     420.960us        12.79%     420.960us      42.096us           0 b           0 b           0 b           0 b            10
autograd::engine::evaluate_function: ConvolutionBack...         0.04%     310.719us         4.47%      39.279ms       3.928ms       0.000us         0.00%     339.520us      33.952us           0 b           0 b      -2.89 Mb      -3.12 Mb            10
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 878.627ms
Self XPU time total: 3.291ms
```

These XPU memory numbers match the profiling results of the same model on CUDA.
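
For programmatic access to the same numbers, a minimal sketch built on `profiling.py` above; it assumes averaged events expose the `device_memory_usage` attribute, which the `test_profiler.py` changes below exercise via `check_metrics`:

```python
# Minimal sketch: print per-op XPU memory usage from the profiler results.
# Assumes `prof` is the profiler object from profiling.py above and that
# averaged events expose `device_memory_usage` (the attribute checked by
# the test_profiler.py changes in this PR).
for evt in prof.key_averages():
    if evt.device_memory_usage != 0:
        print(f"{evt.key}: {evt.device_memory_usage} bytes")
```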
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152842
Approved by: https://github.com/guangyey, https://github.com/sraikund16
Author: Frost Mitchell
Date: 2025-05-21 01:19:19 +00:00
Committed by: PyTorch MergeBot
Parent: 8817e5ac80
Commit: fe49b11e09

5 changed files with 132 additions and 1 deletion

aten/src/ATen/test/CMakeLists.txt

@@ -122,6 +122,7 @@ list(APPEND ATen_XPU_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/xpu_device_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/xpu_event_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/xpu_generator_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/xpu_reportMemoryUsage_test.cpp
)
# ---[ Send the lists to the parent scope.

aten/src/ATen/test/xpu_reportMemoryUsage_test.cpp (new file)

@@ -0,0 +1,69 @@
#include <ATen/test/reportMemoryUsage.h>
#include <gtest/gtest.h>
#include <c10/xpu/XPUCachingAllocator.h>
TEST(DeviceCachingAllocator, check_reporter) {
auto reporter = std::make_shared<TestMemoryReportingInfo>();
c10::DebugInfoGuard guard(c10::DebugInfoKind::PROFILER_STATE, reporter);
auto _200kb = 200 * 1024;
auto _500mb = 500 * 1024 * 1024;
auto allocator = c10::xpu::XPUCachingAllocator::get();
auto alloc1 = allocator->allocate(_200kb);
auto r = reporter->getLatestRecord();
EXPECT_EQ(alloc1.get(), r.ptr);
EXPECT_LE(_200kb, r.alloc_size);
EXPECT_LE(_200kb, r.total_allocated);
EXPECT_LE(_200kb, r.total_reserved);
EXPECT_TRUE(r.device.is_xpu());
auto alloc1_true_ptr = r.ptr;
auto alloc1_true_alloc_size = r.alloc_size;
// PyTorch should not waste that much memory
EXPECT_LT(r.total_allocated, 2 * _200kb);
// PyTorch should not reserve that much memory
EXPECT_LT(r.total_reserved, _500mb);
auto alloc2 = allocator->allocate(_500mb);
r = reporter->getLatestRecord();
EXPECT_EQ(alloc2.get(), r.ptr);
EXPECT_LE(_500mb, r.alloc_size);
EXPECT_LE(_200kb + _500mb, r.total_allocated);
EXPECT_LE(_200kb + _500mb, r.total_reserved);
EXPECT_TRUE(r.device.is_xpu());
auto alloc2_true_ptr = r.ptr;
auto alloc2_true_alloc_size = r.alloc_size;
auto max_reserved = r.total_reserved;
alloc1.clear();
r = reporter->getLatestRecord();
EXPECT_EQ(alloc1_true_ptr, r.ptr);
EXPECT_EQ(-alloc1_true_alloc_size, r.alloc_size);
EXPECT_EQ(alloc2_true_alloc_size, r.total_allocated);
// alloc2 remains; this is a free operation, so it shouldn't reserve more
// memory.
EXPECT_TRUE(
alloc2_true_alloc_size <= static_cast<int64_t>(r.total_reserved) &&
r.total_reserved <= max_reserved);
EXPECT_TRUE(r.device.is_xpu());
alloc2.clear();
r = reporter->getLatestRecord();
EXPECT_EQ(alloc2_true_ptr, r.ptr);
EXPECT_EQ(-alloc2_true_alloc_size, r.alloc_size);
EXPECT_EQ(0, r.total_allocated);
EXPECT_TRUE(r.total_reserved <= max_reserved);
EXPECT_TRUE(r.device.is_xpu());
}
int main(int argc, char* argv[]) {
::testing::InitGoogleTest(&argc, argv);
c10::xpu::XPUCachingAllocator::init(1);
return RUN_ALL_TESTS();
}

c10/xpu/XPUCachingAllocator.cpp

@@ -386,6 +386,13 @@ class DeviceCachingAllocator {
stats.requested_bytes[stat_type].increase(block->requested_size);
});
c10::reportMemoryUsageToProfiler(
block->ptr,
static_cast<int64_t>(block->size),
stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
c10::Device(c10::DeviceType::XPU, device));
return block;
}
@@ -431,6 +438,13 @@ class DeviceCachingAllocator {
auto reserved_bytes =
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
.current;
c10::reportOutOfMemoryToProfiler(
static_cast<int64_t>(size),
allocated_bytes,
reserved_bytes,
c10::Device(c10::DeviceType::XPU, device));
TORCH_CHECK_WITH(
OutOfMemoryError,
false,
@@ -455,6 +469,9 @@ class DeviceCachingAllocator {
std::scoped_lock<std::recursive_mutex> lock(mutex);
block->allocated = false;
auto orig_block_ptr = block->ptr;
auto orig_block_size = block->size;
StatTypes stat_types = get_stat_types_for_pool(*block->pool);
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
stats.allocated_bytes[stat_type].decrease(block->size);
@@ -465,6 +482,13 @@ class DeviceCachingAllocator {
} else {
free_block(block);
}
c10::reportMemoryUsageToProfiler(
orig_block_ptr,
-static_cast<int64_t>(orig_block_size),
stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
c10::Device(c10::DeviceType::XPU, block->device));
}
void recordStream(Block* block, xpu::XPUStream stream) {

test/profiler/test_profiler.py

@@ -605,6 +605,9 @@ class TestProfiler(TestCase):
def create_cuda_tensor():
return torch.rand(10, 10).cuda()
def create_xpu_tensor():
return torch.rand(10, 10).xpu()
def create_mkldnn_tensor():
return torch.rand(10, 10, dtype=torch.float32).to_mkldnn()
@@ -675,6 +678,30 @@ class TestProfiler(TestCase):
],
)
if torch.xpu.is_available():
create_xpu_tensor()
stats = run_profiler(create_xpu_tensor)
check_metrics(
stats,
"device_memory_usage",
allocs=[
"test_user_scope_alloc",
"aten::to",
"aten::empty_strided",
],
deallocs=[
"test_user_scope_dealloc",
],
)
check_metrics(
stats,
"cpu_memory_usage",
allocs=[
"aten::rand",
"aten::empty",
],
)
if torch.backends.mkldnn.is_available():
create_mkldnn_tensor()
stats = run_profiler(create_mkldnn_tensor)
@@ -699,6 +726,9 @@ class TestProfiler(TestCase):
if torch.cuda.is_available():
y = torch.rand(10, 10).cuda()
del y
elif torch.xpu.is_available():
y = torch.rand(10, 10).to("xpu")
del y
gc.collect()
stats = prof.key_averages(group_by_input_shape=True)
check_metrics(
@@ -709,6 +739,8 @@ class TestProfiler(TestCase):
)
if torch.cuda.is_available():
check_metrics(stats, "device_memory_usage", deallocs=["[memory]"])
elif torch.xpu.is_available():
check_metrics(stats, "device_memory_usage", deallocs=["[memory]"])
@unittest.skipIf(
IS_JETSON, "Jetson has a guard against OOM since host and gpu memory are shared"

torch/autograd/profiler.py

@@ -563,7 +563,12 @@ class profile:
return (
mem_record.nbytes()
if mem_record.device_type()
in [DeviceType.CUDA, DeviceType.PrivateUse1, DeviceType.HIP]
in [
DeviceType.CUDA,
DeviceType.PrivateUse1,
DeviceType.HIP,
DeviceType.XPU,
]
else 0
)