Support XPU in memory tracker (#150703)

This PR adds support for XPU devices to the distributed MemoryTracker tool, including a unit test for XPU.

Specifically, this code adds tracking of a few allocation-related statistics (such as `num_alloc_retries`) to XPUCachingAllocator. It also makes the existing memory tracker tool device-agnostic by fetching the device module and recording the necessary memory stats through it. (I use the device module rather than `torch.accelerator` methods, as that API is still in progress.)
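As a rough sketch of the device-agnostic pattern (illustrative only, not the exact code added by this PR; the variable names below are made up), the tracker resolves the active accelerator's module once and queries all memory statistics through it:

```python
import torch

# Resolve the module for the current accelerator (torch.cuda, torch.xpu, ...).
device_module = torch.get_device_module()

# Free cached blocks so profiling starts from a clean allocator state.
device_module.empty_cache()

# CUDA and XPU expose allocator statistics under the same keys, so
# "num_alloc_retries" can be read without device-specific branching.
stats = device_module.memory_stats()
num_alloc_retries = stats.get("num_alloc_retries", 0)
allocated_mb = device_module.memory_allocated() / (1024 * 1024)
print(f"alloc retries: {num_alloc_retries}, allocated: {allocated_mb:.1f} MB")
```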

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150703
Approved by: https://github.com/EikanWang, https://github.com/guangyey, https://github.com/gujinghui, https://github.com/d4l3k
Frost Mitchell
2025-06-12 21:33:52 +00:00
committed by PyTorch MergeBot
parent 154a39bfbd
commit db01f1032f
4 changed files with 38 additions and 26 deletions

View File

@@ -251,9 +251,12 @@ class DeviceCachingAllocator {
     return true;
   }
-  bool alloc_block(AllocParams& p) {
+  bool alloc_block(AllocParams& p, bool isRetry) {
     auto size = p.alloc_size;
     auto device = p.device();
+    if (isRetry) {
+      stats.num_alloc_retries += 1;
+    }
     void* ptr = sycl::aligned_alloc_device(
         kDeviceAlignment,
         size,
@@ -425,8 +428,8 @@ class DeviceCachingAllocator {
     bool block_found = get_free_block(params);
     // Can't reuse an existing block, try to get a new one.
     if (!block_found) {
-      block_found = alloc_block(params) ||
-          (release_cached_blocks() && alloc_block(params));
+      block_found = alloc_block(params, false) ||
+          (release_cached_blocks() && alloc_block(params, true));
     }
     if (!block_found) {
       c10::xpu::DeviceProp device_prop;
@@ -519,6 +522,7 @@ class DeviceCachingAllocator {
       stats.active_bytes[statType].reset_accumulated();
       stats.requested_bytes[statType].reset_accumulated();
     }
+    stats.num_alloc_retries = 0;
   }
   void resetPeakStats() {
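With `num_alloc_retries` maintained by the XPU caching allocator, the retry count should now surface through the regular stats query. A minimal sketch, assuming `torch.xpu.memory_stats()` mirrors the CUDA key layout:

```python
import torch

if torch.xpu.is_available():
    # After this change, the XPU allocator is expected to report retry counts
    # under the same key the CUDA allocator uses.
    retries = torch.xpu.memory_stats().get("num_alloc_retries", 0)
    print(f"XPU allocation retries so far: {retries}")
```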

View File

@@ -6,17 +6,18 @@ import torch
 import torch.nn as nn
 from torch.distributed._tools import MemoryTracker
 from torch.testing._internal.common_cuda import TEST_CUDA
-from torch.testing._internal.common_utils import run_tests, TestCase
+from torch.testing._internal.common_utils import run_tests, TEST_XPU, TestCase
 class TestMemoryTracker(TestCase):
-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "no cuda/xpu")
     def test_local_model(self):
         """
         Minimal test case to check the memory tracker can collect the expected
         memory stats at operator level, as well as can print the summary result
         without crash.
         """
+        device = "cuda" if TEST_CUDA else "xpu"
         # Create a model with a hierarchy of modules
         torch.manual_seed(0)
         model = nn.Sequential(
@@ -28,16 +29,16 @@ class TestMemoryTracker(TestCase):
             ),
             nn.Flatten(start_dim=1),
             nn.Sequential(nn.Linear(64, 2), nn.ReLU(inplace=True)),
-        ).cuda()
+        ).to(device)
         # Run one iteration of forward and backward pass
         tracker = MemoryTracker()
         tracker.start_monitor(model)
-        x = torch.randn(size=(2, 3, 224, 224), device=torch.device("cuda"))
-        # torch.LongTensor expects cpu device type, not cuda device type in
-        # constructor, so calling .cuda() outside constructor here.
-        target = torch.LongTensor([0, 1]).cuda()
+        x = torch.randn(size=(2, 3, 224, 224), device=torch.device(device))
+        # torch.LongTensor expects cpu device type, not device type in
+        # constructor, so calling .to(device) outside constructor here.
+        target = torch.LongTensor([0, 1]).to(device)
         criterion = nn.CrossEntropyLoss()
         criterion(model(x), target).backward()
@@ -61,7 +62,7 @@ class TestMemoryTracker(TestCase):
         self.assertEqual(len(tracker.memories_reserved), tracker._op_index)
         self.assertTrue(len(tracker._markers) == 2)
         self.assertTrue(tracker._cur_module_name != "")
-        self.assertTrue(hasattr(tracker, "_num_cuda_retries"))
+        self.assertTrue(hasattr(tracker, "_num_alloc_retries"))
 if __name__ == "__main__":

View File

@@ -81,7 +81,8 @@ class MemoryTracker:
         self._markers: dict[str, int] = defaultdict(int)
         self._cur_module_name: str = ""
         self._op_index: int = 0
-        self._num_cuda_retries: int = 0
+        self._num_alloc_retries: int = 0
+        self._device_module = torch.get_device_module()
     @no_type_check
     def start_monitor(self, root_module: nn.Module) -> None:
@@ -106,7 +107,7 @@ class MemoryTracker:
             # clear and remove it for now as it does not really capture important info.
             # h3 = m.register_backward_hook(self._create_backward_hook(name))
             self._hooks.extend([h1, h2])
-        torch.cuda.empty_cache()
+        self._device_module.empty_cache()
         assert getattr(self, "profile_mode", None) is None
         self.profile_mode = MemoryProfileDispatchMode(self)
         self.profile_mode.__enter__()
@@ -116,9 +117,11 @@ class MemoryTracker:
         """
         Remove module hooks and exit ``MemoryProfileDispatchMode`` to stop tracking memory stats at operator level.
-        Get some aggregated stats when the memory_tracker() is enabled, like cuda ``num_alloc_retries``.
+        Get some aggregated stats when the memory_tracker() is enabled, like ``num_alloc_retries``.
         """
-        self._num_cuda_retries = torch.cuda.memory_stats().get("num_alloc_retries", 0)
+        self._num_alloc_retries = self._device_module.memory_stats().get(
+            "num_alloc_retries", 0
+        )
         for h in self._hooks:
             h.remove()
@@ -142,7 +145,7 @@ class MemoryTracker:
             previous_allocated_memory = current_allocated_memory
         print("------------------------------------------------")
-        print(f"The number of cuda retries are: {self._num_cuda_retries}")
+        print(f"The number of alloc retries are: {self._num_alloc_retries}")
         print(f"Top {top} ops that generates memory are:")
         for k, v in sorted(op_diff.items(), key=operator.itemgetter(1), reverse=True)[
             :top
@@ -206,7 +209,7 @@ class MemoryTracker:
             "memories_active": self.memories_active,
             "memories_reserved": self.memories_reserved,
             "markers": self._markers,
-            "num_alloc_retries": self._num_cuda_retries,
+            "num_alloc_retries": self._num_alloc_retries,
         }
         with open(path, "wb") as f:
@@ -221,7 +224,7 @@ class MemoryTracker:
         self.memories_active = stats["memories_active"]
         self.memories_reserved = stats["memories_reserved"]
         self._markers = stats["markers"]
-        self._num_cuda_retries = stats["num_alloc_retries"]
+        self._num_alloc_retries = stats["num_alloc_retries"]
     def _create_pre_forward_hook(self, name: str) -> Callable:
         """Prefix operator name with current module and 'forward', and insert 'fw_start' marker at forward pass start."""
@@ -269,10 +272,11 @@ class MemoryTracker:
         The memory stats dict is indexed with ``self._op_index``.
         """
-        memory_allocated: float = torch.cuda.memory_allocated() / BYTES_PER_MB
-        memory_reserved: float = torch.cuda.memory_reserved() / BYTES_PER_MB
+        memory_allocated: float = self._device_module.memory_allocated() / BYTES_PER_MB
+        memory_reserved: float = self._device_module.memory_reserved() / BYTES_PER_MB
         memory_active: float = (
-            torch.cuda.memory_stats().get("active_bytes.all.current", 0) / BYTES_PER_MB
+            self._device_module.memory_stats().get("active_bytes.all.current", 0)
+            / BYTES_PER_MB
         )
         self.memories_allocated[self._op_index] = (fn_name, memory_allocated)
         self.memories_reserved[self._op_index] = (fn_name, memory_reserved)
@@ -293,4 +297,4 @@ class MemoryTracker:
         self._markers.clear()
         self._cur_module_name = ""
         self._op_index = 0
-        self._num_cuda_retries = 0
+        self._num_alloc_retries = 0

View File

@@ -3,9 +3,9 @@ import torch
 from torch.distributed._tools import MemoryTracker
-def run_one_model(net: torch.nn.Module, input: torch.Tensor):
-    net.cuda()
-    input = input.cuda()
+def run_one_model(net: torch.nn.Module, input: torch.Tensor, device: str = "cuda"):
+    net.to(device)
+    input = input.to(device)
     # Create the memory Tracker
     mem_tracker = MemoryTracker()
@@ -31,6 +31,9 @@ def run_one_model(net: torch.nn.Module, input: torch.Tensor):
 if __name__ == "__main__":
     import torchvision
+    dev = "cuda"
     run_one_model(
-        torchvision.models.resnet34(), torch.rand(32, 3, 224, 224, device="cuda")
+        torchvision.models.resnet34(),
+        torch.rand(32, 3, 224, 224, device=dev),
+        device=dev,
     )
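For completeness, a hedged end-to-end sketch of driving the tracker on whichever backend is available; the device-selection logic and the `stop()`/`summary()` calls are assumed MemoryTracker entry points that are not shown in this diff:

```python
import torch
import torchvision

from torch.distributed._tools import MemoryTracker

# Prefer XPU when present, otherwise fall back to CUDA (illustrative choice).
dev = "xpu" if torch.xpu.is_available() else "cuda"

net = torchvision.models.resnet34().to(dev)
x = torch.rand(32, 3, 224, 224, device=dev)

tracker = MemoryTracker()
tracker.start_monitor(net)
net(x).sum().backward()  # one forward/backward pass under tracking
tracker.stop()
tracker.summary()  # prints top memory-producing ops and the alloc retry count
```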