pytorch/torch/distributed/examples/memory_tracker_example.py
Frost Mitchell db01f1032f Support XPU in memory tracker (#150703)
This PR adds support for XPU devices to the distributed MemoryTracker tool, including a unit test for XPU.

Specifically, this code adds tracking for a few alloc-related statistics to XPUCachingAllocator. It also adapts the existing memory tracker tool to be device-agnostic by getting the device module and recording the necessary memory stats. (I get the device module instead of using `torch.accelerator` methods, as that API is still in progress.)
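
As a rough sketch of the device-agnostic approach described above (not the PR's actual implementation), memory stats can be read through the device module resolved at runtime. `get_alloc_stats` below is a hypothetical helper; it assumes `torch.get_device_module` and the per-backend `memory_stats()` API are available:

import torch

def get_alloc_stats(device: str) -> dict:
    # Resolve torch.cuda, torch.xpu, etc. from the device string
    # instead of hard-coding torch.cuda.
    dev_module = torch.get_device_module(device)
    # The caching allocators expose alloc-related counters here,
    # e.g. "allocated_bytes.all.current" (XPU key names assumed
    # to mirror CUDA's).
    return dev_module.memory_stats()

# Works unchanged for "cuda" or "xpu" on a build with that backend enabled:
# stats = get_alloc_stats("xpu")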

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150703
Approved by: https://github.com/EikanWang, https://github.com/guangyey, https://github.com/gujinghui, https://github.com/d4l3k
2025-06-12 21:33:52 +00:00


# mypy: allow-untyped-defs
import torch

from torch.distributed._tools import MemoryTracker


def run_one_model(net: torch.nn.Module, input: torch.Tensor, device: str = "cuda"):
    net.to(device)
    input = input.to(device)

    # Create the memory Tracker
    mem_tracker = MemoryTracker()
    # start_monitor before the training iteration starts
    mem_tracker.start_monitor(net)

    # run one training iteration
    net.zero_grad(True)
    loss = net(input)
    if isinstance(loss, dict):
        loss = loss["out"]
    loss.sum().backward()
    net.zero_grad(set_to_none=True)

    # stop monitoring after the training iteration ends
    mem_tracker.stop()
    # print the memory stats summary
    mem_tracker.summary()
    # plot the memory traces at operator level
    mem_tracker.show_traces()


if __name__ == "__main__":
    import torchvision

    dev = "cuda"
    run_one_model(
        torchvision.models.resnet34(),
        torch.rand(32, 3, 224, 224, device=dev),
        device=dev,
    )
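
With the XPU support added by this PR, the same example can target an Intel GPU just by changing the device string. The variant below assumes a PyTorch build with XPU enabled:

    dev = "xpu" if torch.xpu.is_available() else "cuda"
    run_one_model(
        torchvision.models.resnet34(),
        torch.rand(32, 3, 224, 224, device=dev),
        device=dev,
    )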