Support XPU in memory tracker (#150703)
This PR adds support for XPU devices to the distributed MemoryTracker tool, including a unit test for XPU. Specifically, it adds tracking of a few allocation-related statistics to the XPUCachingAllocator. It also adapts the existing memory tracker tool to be device agnostic, by getting the device module and recording the necessary memory stats through it. (I get the device module instead of using `torch.accelerator` methods, as that API is still in progress.)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150703
Approved by: https://github.com/EikanWang, https://github.com/guangyey, https://github.com/gujinghui, https://github.com/d4l3k
Committed by: PyTorch MergeBot
Parent: 154a39bfbd
Commit: db01f1032f
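The device-agnostic approach described above boils down to asking torch for the backend's device module and reading the same stat names on CUDA and XPU. A minimal sketch of that pattern, not code from this commit (the helper name collect_memory_stats is made up for illustration):

import torch

def collect_memory_stats() -> dict:
    # torch.get_device_module() returns torch.cuda on a CUDA machine and
    # torch.xpu on an XPU machine, so one code path covers both backends.
    device_module = torch.get_device_module()
    stats = device_module.memory_stats()
    bytes_per_mb = 1024 * 1024
    return {
        "allocated_mb": device_module.memory_allocated() / bytes_per_mb,
        "reserved_mb": device_module.memory_reserved() / bytes_per_mb,
        "active_mb": stats.get("active_bytes.all.current", 0) / bytes_per_mb,
        "alloc_retries": stats.get("num_alloc_retries", 0),
    }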
@@ -251,9 +251,12 @@ class DeviceCachingAllocator {
     return true;
   }
 
-  bool alloc_block(AllocParams& p) {
+  bool alloc_block(AllocParams& p, bool isRetry) {
     auto size = p.alloc_size;
     auto device = p.device();
+    if (isRetry) {
+      stats.num_alloc_retries += 1;
+    }
     void* ptr = sycl::aligned_alloc_device(
         kDeviceAlignment,
         size,
@@ -425,8 +428,8 @@ class DeviceCachingAllocator {
     bool block_found = get_free_block(params);
     // Can't reuse an existing block, try to get a new one.
     if (!block_found) {
-      block_found = alloc_block(params) ||
-          (release_cached_blocks() && alloc_block(params));
+      block_found = alloc_block(params, false) ||
+          (release_cached_blocks() && alloc_block(params, true));
     }
     if (!block_found) {
       c10::xpu::DeviceProp device_prop;
@@ -519,6 +522,7 @@ class DeviceCachingAllocator {
       stats.active_bytes[statType].reset_accumulated();
       stats.requested_bytes[statType].reset_accumulated();
     }
+    stats.num_alloc_retries = 0;
   }
 
   void resetPeakStats() {
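For context on the allocator change above: counting num_alloc_retries in the XPU caching allocator is what lets the Python-side tracker read the value back through memory_stats(), which is why the tracker code below uses .get("num_alloc_retries", 0) with a default. A small sketch of reading it directly, assuming an XPU build where torch.xpu.memory_stats() is available:

import torch

if torch.xpu.is_available():
    # After this change, the allocator bumps the counter each time alloc_block
    # has to retry after releasing cached blocks.
    retries = torch.xpu.memory_stats().get("num_alloc_retries", 0)
    print(f"XPU allocation retries so far: {retries}")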
@@ -6,17 +6,18 @@ import torch
 import torch.nn as nn
 from torch.distributed._tools import MemoryTracker
 from torch.testing._internal.common_cuda import TEST_CUDA
-from torch.testing._internal.common_utils import run_tests, TestCase
+from torch.testing._internal.common_utils import run_tests, TEST_XPU, TestCase
 
 
 class TestMemoryTracker(TestCase):
-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "no cuda/xpu")
     def test_local_model(self):
         """
         Minimal test case to check the memory tracker can collect the expected
         memory stats at operator level, as well as can print the summary result
         without crash.
         """
+        device = "cuda" if TEST_CUDA else "xpu"
         # Create a model with a hierarchy of modules
         torch.manual_seed(0)
         model = nn.Sequential(
@@ -28,16 +29,16 @@ class TestMemoryTracker(TestCase):
             ),
             nn.Flatten(start_dim=1),
             nn.Sequential(nn.Linear(64, 2), nn.ReLU(inplace=True)),
-        ).cuda()
+        ).to(device)
 
         # Run one iteration of forward and backward pass
         tracker = MemoryTracker()
         tracker.start_monitor(model)
 
-        x = torch.randn(size=(2, 3, 224, 224), device=torch.device("cuda"))
-        # torch.LongTensor expects cpu device type, not cuda device type in
-        # constructor, so calling .cuda() outside constructor here.
-        target = torch.LongTensor([0, 1]).cuda()
+        x = torch.randn(size=(2, 3, 224, 224), device=torch.device(device))
+        # torch.LongTensor expects cpu device type, not device type in
+        # constructor, so calling .to(device) outside constructor here.
+        target = torch.LongTensor([0, 1]).to(device)
         criterion = nn.CrossEntropyLoss()
         criterion(model(x), target).backward()
 
@@ -61,7 +62,7 @@ class TestMemoryTracker(TestCase):
         self.assertEqual(len(tracker.memories_reserved), tracker._op_index)
         self.assertTrue(len(tracker._markers) == 2)
         self.assertTrue(tracker._cur_module_name != "")
-        self.assertTrue(hasattr(tracker, "_num_cuda_retries"))
+        self.assertTrue(hasattr(tracker, "_num_alloc_retries"))
 
 
 if __name__ == "__main__":
@@ -81,7 +81,8 @@ class MemoryTracker:
         self._markers: dict[str, int] = defaultdict(int)
         self._cur_module_name: str = ""
         self._op_index: int = 0
-        self._num_cuda_retries: int = 0
+        self._num_alloc_retries: int = 0
+        self._device_module = torch.get_device_module()
 
     @no_type_check
     def start_monitor(self, root_module: nn.Module) -> None:
@@ -106,7 +107,7 @@ class MemoryTracker:
             # clear and remove it for now as it does not really capture important info.
             # h3 = m.register_backward_hook(self._create_backward_hook(name))
             self._hooks.extend([h1, h2])
-        torch.cuda.empty_cache()
+        self._device_module.empty_cache()
         assert getattr(self, "profile_mode", None) is None
         self.profile_mode = MemoryProfileDispatchMode(self)
         self.profile_mode.__enter__()
@@ -116,9 +117,11 @@ class MemoryTracker:
         """
         Remove module hooks and exit ``MemoryProfileDispatchMode`` to stop tracking memory stats at operator level.
 
-        Get some aggregated stats when the memory_tracker() is enabled, like cuda ``num_alloc_retries``.
+        Get some aggregated stats when the memory_tracker() is enabled, like ``num_alloc_retries``.
         """
-        self._num_cuda_retries = torch.cuda.memory_stats().get("num_alloc_retries", 0)
+        self._num_alloc_retries = self._device_module.memory_stats().get(
+            "num_alloc_retries", 0
+        )
 
         for h in self._hooks:
             h.remove()
@@ -142,7 +145,7 @@ class MemoryTracker:
             previous_allocated_memory = current_allocated_memory
 
         print("------------------------------------------------")
-        print(f"The number of cuda retries are: {self._num_cuda_retries}")
+        print(f"The number of alloc retries are: {self._num_alloc_retries}")
         print(f"Top {top} ops that generates memory are:")
         for k, v in sorted(op_diff.items(), key=operator.itemgetter(1), reverse=True)[
             :top
@@ -206,7 +209,7 @@ class MemoryTracker:
             "memories_active": self.memories_active,
             "memories_reserved": self.memories_reserved,
             "markers": self._markers,
-            "num_alloc_retries": self._num_cuda_retries,
+            "num_alloc_retries": self._num_alloc_retries,
         }
 
         with open(path, "wb") as f:
@@ -221,7 +224,7 @@ class MemoryTracker:
         self.memories_active = stats["memories_active"]
         self.memories_reserved = stats["memories_reserved"]
         self._markers = stats["markers"]
-        self._num_cuda_retries = stats["num_alloc_retries"]
+        self._num_alloc_retries = stats["num_alloc_retries"]
 
     def _create_pre_forward_hook(self, name: str) -> Callable:
         """Prefix operator name with current module and 'forward', and insert 'fw_start' marker at forward pass start."""
@@ -269,10 +272,11 @@ class MemoryTracker:
 
         The memory stats dict is indexed with ``self._op_index``.
         """
-        memory_allocated: float = torch.cuda.memory_allocated() / BYTES_PER_MB
-        memory_reserved: float = torch.cuda.memory_reserved() / BYTES_PER_MB
+        memory_allocated: float = self._device_module.memory_allocated() / BYTES_PER_MB
+        memory_reserved: float = self._device_module.memory_reserved() / BYTES_PER_MB
         memory_active: float = (
-            torch.cuda.memory_stats().get("active_bytes.all.current", 0) / BYTES_PER_MB
+            self._device_module.memory_stats().get("active_bytes.all.current", 0)
+            / BYTES_PER_MB
         )
         self.memories_allocated[self._op_index] = (fn_name, memory_allocated)
         self.memories_reserved[self._op_index] = (fn_name, memory_reserved)
@@ -293,4 +297,4 @@ class MemoryTracker:
         self._markers.clear()
         self._cur_module_name = ""
         self._op_index = 0
-        self._num_cuda_retries = 0
+        self._num_alloc_retries = 0
@@ -3,9 +3,9 @@ import torch
 from torch.distributed._tools import MemoryTracker
 
 
-def run_one_model(net: torch.nn.Module, input: torch.Tensor):
-    net.cuda()
-    input = input.cuda()
+def run_one_model(net: torch.nn.Module, input: torch.Tensor, device: str = "cuda"):
+    net.to(device)
+    input = input.to(device)
 
     # Create the memory Tracker
     mem_tracker = MemoryTracker()
@@ -31,6 +31,9 @@ def run_one_model(net: torch.nn.Module, input: torch.Tensor):
 if __name__ == "__main__":
     import torchvision
 
+    dev = "cuda"
     run_one_model(
-        torchvision.models.resnet34(), torch.rand(32, 3, 224, 224, device="cuda")
+        torchvision.models.resnet34(),
+        torch.rand(32, 3, 224, 224, device=dev),
+        device=dev,
     )
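The example's __main__ still hard-codes dev = "cuda", but since run_one_model now accepts a device argument it could just as well target XPU. A possible variant of that block, not part of this commit, assuming torchvision is installed:

if __name__ == "__main__":
    import torchvision

    # Prefer XPU when present, otherwise fall back to CUDA.
    dev = "xpu" if torch.xpu.is_available() else "cuda"
    run_one_model(
        torchvision.models.resnet34(),
        torch.rand(32, 3, 224, 224, device=dev),
        device=dev,
    )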