This PR adds support for XPU devices to the distributed MemoryTracker tool, including a unit test for XPU. Specifically, this code adds tracking for a few alloc-related statistics for XPUCachingAllocator. It also adapts the existing memory tracker tool to be device-agnostic by getting the device module and recording the necessary memory stats. (I get the device module instead of using `torch.accelerator` methods, as that API is still in progress.)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150703
Approved by: https://github.com/EikanWang, https://github.com/guangyey, https://github.com/gujinghui, https://github.com/d4l3k
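
As a rough sketch of the device-agnostic pattern described above (not the PR's actual code): the tracker can resolve the backend module from a device-type string instead of hard-coding `torch.cuda`, and then read allocator statistics from it. `_alloc_stats` is a hypothetical helper name; `torch.get_device_module` and a backend-level `memory_stats()` call are assumed to be available for both CUDA and XPU.

import torch


def _alloc_stats(device_type: str = "cuda") -> dict:
    # Hypothetical helper: resolve torch.cuda / torch.xpu / ... from the
    # device-type string rather than hard-coding torch.cuda.
    dev_module = torch.get_device_module(device_type)
    # Assumes the backend's caching allocator exposes memory_stats(), as the
    # CUDA and XPU caching allocators do; keep a couple of alloc-related
    # counters for the memory summary.
    stats = dev_module.memory_stats()
    return {
        "allocated_bytes.all.current": stats.get("allocated_bytes.all.current", 0),
        "reserved_bytes.all.current": stats.get("reserved_bytes.all.current", 0),
    }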
# mypy: allow-untyped-defs

import torch

from torch.distributed._tools import MemoryTracker


def run_one_model(net: torch.nn.Module, input: torch.Tensor, device: str = "cuda"):
    net.to(device)
    input = input.to(device)

    # Create the memory Tracker
    mem_tracker = MemoryTracker()
    # start_monitor before the training iteration starts
    mem_tracker.start_monitor(net)

    # run one training iteration
    net.zero_grad(True)
    loss = net(input)
    if isinstance(loss, dict):
        loss = loss["out"]
    loss.sum().backward()
    net.zero_grad(set_to_none=True)

    # stop monitoring after the training iteration ends
    mem_tracker.stop()
    # print the memory stats summary
    mem_tracker.summary()
    # plot the memory traces at operator level
    mem_tracker.show_traces()


if __name__ == "__main__":
    import torchvision

    dev = "cuda"
    run_one_model(
        torchvision.models.resnet34(),
        torch.rand(32, 3, 224, 224, device=dev),
        device=dev,
    )
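
With the XPU support this PR adds, the same example should be runnable against an XPU device. A hedged usage sketch that reuses `run_one_model` from the file above and picks XPU when available (assuming `torch.xpu.is_available()` as the availability check):

import torch
import torchvision

# Prefer an XPU device when one is present, otherwise fall back to CUDA.
dev = "xpu" if torch.xpu.is_available() else "cuda"
run_one_model(
    torchvision.models.resnet34(),
    torch.rand(32, 3, 224, 224, device=dev),
    device=dev,
)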