[Profiler] Memory profiler part 1: Gradient identification (#86802)
There are multiple ways to identify that a Tensor is a gradient, a subset of which also give additional context. So to start off I've made a utility to handle that determination.

Differential Revision: [D39920730](https://our.internmc.facebook.com/intern/diff/D39920730/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86802
Approved by: https://github.com/chaekit
committed by PyTorch MergeBot
parent c0e6b4329f · commit cef13ebea0
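The new `extract_gradients` utility (full file below) walks profiler events and yields `(parameter, gradient)` key pairs. A minimal sketch of how it is driven, assembled from the tests in this commit; the profiler arguments and traversal helpers are exactly the ones the test file uses, while the tiny workload is illustrative:

    import torch
    from torch.profiler import _memory_profiler, _utils

    w = torch.ones((1,), requires_grad=True)
    with torch.profiler.profile(
        record_shapes=True, profile_memory=True, with_stack=True
    ) as prof:
        (torch.ones((1,)) * w).sum().backward()

    # Walk the event tree and report every Tensor identified as a gradient.
    tree = prof.profiler.kineto_results.experimental_event_tree()
    for node in _utils.traverse_dfs(tree):
        for param_key, grad_key in _memory_profiler.extract_gradients(node):
            # `param_key` is None when the gradient was found by op inspection
            # (AccumulateGrad) rather than Module/Optimizer instrumentation.
            print(node.name, param_key, grad_key)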
mypy-strict.ini
@@ -40,6 +40,7 @@ files =
     .github,
     benchmarks/instruction_counts,
     tools,
+    torch/profiler/_memory_profiler.py,
     torch/utils/_pytree.py,
     torch/utils/benchmark/utils/common.py,
     torch/utils/benchmark/utils/timer.py,
test/profiler/test_memory_profiler.py (new file)
@@ -0,0 +1,224 @@
# Owner(s): ["oncall: profiler"]
import functools
from typing import Iterator, Optional

import torch
from torch._C._profiler import _EventType
from torch.profiler import _memory_profiler, _utils
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase


profile = functools.partial(
    torch.profiler.profile, record_shapes=True, profile_memory=True, with_stack=True
)


class ScaleLayer(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.scale = torch.nn.Parameter(torch.rand(()), requires_grad=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.scale


@skipIfTorchDynamo("TorchDynamo changes Python calls that memory profiling relies on.")
class TestIdentifyGradients(TestCase):
    def gradient_detected(
        self,
        prof: torch.profiler.profile,
        ctx: _EventType,
        grad_tensor: torch.Tensor,
        parameter: Optional[torch.Tensor] = None,
    ) -> bool:

        # This is not an exhaustive check, but for the purpose of unit testing
        # it is sufficient.
        def key_matches_tensor(key, tensor) -> bool:
            # Vacuous case.
            if tensor is None:
                return True

            if key is None:
                return False

            return tensor.storage().data_ptr() == key.storage.ptr

        tree = prof.profiler.kineto_results.experimental_event_tree()
        for node in _utils.traverse_dfs(tree):
            for p_key, p_grad_key in _memory_profiler.extract_gradients(node):
                if node.tag == ctx and key_matches_tensor(p_grad_key, grad_tensor):
                    if parameter is None:
                        return True  # Don't need to check parameter; we're done.

                    elif p_key is not None:
                        # For a complex workflow a gradient could correspond to
                        # different parameters at different points in a trace.
                        # However this will not happen in the relatively simple
                        # cases tested here, so if `extract_gradients` identifies
                        # the parameter corresponding to a particular gradient it
                        # must be the one we expect.
                        self.assertTrue(key_matches_tensor(p_key, parameter))
                        return True

        return False

    def assertGradientDetected(self, name: str, *args, **kwargs) -> None:
        self.assertTrue(
            self.gradient_detected(*args, **kwargs),
            f"Failed to identify gradient `{name}` from profile.",
        )

    def assertOnlyGradients(
        self, prof: torch.profiler.profile, tensors: Iterator[torch.Tensor]
    ) -> None:
        allowed_set = {t.storage().data_ptr() for t in tensors}

        tree = prof.profiler.kineto_results.experimental_event_tree()
        for node in _utils.traverse_dfs(tree):
            for _, p_grad_key in _memory_profiler.extract_gradients(node):
                self.assertTrue(
                    p_grad_key.storage.ptr in allowed_set,
                    f"Tensor wrongly marked as gradient: {node.name}: {p_grad_key}",
                )

    def test_extract_gradients_low_level(self) -> None:
        x = torch.ones((1,))
        w0 = torch.ones((1,), requires_grad=True)
        w1 = torch.ones((1,), requires_grad=True)

        def check(cold_start: bool):
            self.assertEqual(w0.grad is None, cold_start)
            self.assertEqual(w1.grad is None, cold_start)
            with profile() as prof:
                z = x.expand(4) * w0
                (z * w1).sum().backward()

            # Gradient detection through op inspection does not provide a
            # reference to the parameter corresponding to the gradient.
            self.assertGradientDetected("w0", prof, _EventType.TorchOp, w0.grad)
            self.assertGradientDetected("w1", prof, _EventType.TorchOp, w1.grad)
            self.assertOnlyGradients(prof, (w0.grad, w1.grad))

        check(cold_start=True)
        check(cold_start=False)

    def test_extract_gradients_from_module(self) -> None:
        model = torch.nn.Sequential(torch.nn.Linear(2, 1), ScaleLayer())
        named_parameters = {name: p for name, p in model.named_parameters()}
        self.assertEqual(len(named_parameters), 3)

        def assert_only_gradients(prof: torch.profiler.profile):
            gradients = tuple(i.grad for i in named_parameters.values())
            self.assertFalse(any(i is None for i in gradients))
            self.assertOnlyGradients(prof, gradients)

        def check(cold_start: bool):
            x = torch.ones((2, 2))
            with profile() as prof:
                model(x).sum().backward()

            for name, p in named_parameters.items():
                # The first time we run a module none of the `.grad` fields
                # have been initialized. This is fine; in that case we can
                # detect everything we need in the profiled section.
                self.assertNotEqual(
                    self.gradient_detected(prof, _EventType.PyCall, p.grad, p),
                    cold_start,
                    name,
                )

                # Op based detection should still identify the gradients.
                self.assertGradientDetected(name, prof, _EventType.TorchOp, p.grad)
            assert_only_gradients(prof)

            # We can detect gradients even when `.backward()` is not called.
            with profile() as prof:
                model(torch.ones((2, 2)))

            for name, p in named_parameters.items():
                self.assertGradientDetected(name, prof, _EventType.PyCall, p.grad, p)
                self.assertFalse(
                    self.gradient_detected(prof, _EventType.TorchOp, p.grad), name
                )
            assert_only_gradients(prof)

        check(cold_start=True)
        check(cold_start=False)

    def _test_extract_gradients_from_optimizer(self, set_to_none: bool) -> None:

        x = torch.ones((1,))
        w0 = torch.ones((1,), requires_grad=True)
        w1 = torch.ones((1,), requires_grad=True)
        optimizer = torch.optim.SGD((w0, w1), lr=0.1, momentum=0.9)

        def check(cold_start: bool):
            self.assertEqual(w0.grad is None, cold_start)
            self.assertEqual(w1.grad is None, cold_start)
            with profile() as prof:
                optimizer.zero_grad(set_to_none=set_to_none)
                z = x.expand(4) * w0
                (z * w1).sum().backward()
                optimizer.step()

            # Optimizer instrumentation runs late in the step, so we can detect
            # gradients for both cold and warm start.
            self.assertGradientDetected("w0", prof, _EventType.PyCall, w0.grad, w0)
            self.assertGradientDetected("w1", prof, _EventType.PyCall, w1.grad, w1)

            self.assertGradientDetected("w0", prof, _EventType.TorchOp, w0.grad)
            self.assertGradientDetected("w1", prof, _EventType.TorchOp, w1.grad)
            self.assertOnlyGradients(prof, (w0.grad, w1.grad))

            with profile() as prof:
                for _ in range(2):
                    optimizer.zero_grad(set_to_none=set_to_none)
                    z = x.expand(4) * w0
                    (z * w1).sum().backward()
                    optimizer.step()

            # Inspected state is cached, so if we replace gradients (as is the
            # case for `set_to_none=True`) our python instrumentation will not
            # see them.
            # TODO(robieta): Should `.step()` be excluded from caching?
            self.assertNotEqual(
                self.gradient_detected(prof, _EventType.PyCall, w0.grad, w0),
                set_to_none,
            )

            self.assertNotEqual(
                self.gradient_detected(prof, _EventType.PyCall, w1.grad, w1),
                set_to_none,
            )

            if set_to_none:
                with self.assertRaisesRegex(AssertionError, "Tensor wrongly marked"):
                    self.assertOnlyGradients(prof, (w0.grad, w1.grad))

        check(cold_start=True)
        check(cold_start=False)

    def test_extract_gradients_from_optimizer(self) -> None:
        self._test_extract_gradients_from_optimizer(set_to_none=False)

    def test_extract_gradients_from_optimizer_set_to_none(self) -> None:
        self._test_extract_gradients_from_optimizer(set_to_none=True)

    def test_extract_gradients_from_module_and_optimizer(self) -> None:
        # Module and optimizer are thoroughly tested individually and should be
        # additive. Thus we can manage with a lightweight check that they don't
        # interact adversely.
        model = torch.nn.Sequential(torch.nn.Linear(2, 1), ScaleLayer())
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        with profile() as prof:
            model(torch.ones((2, 2))).sum().backward()
            optimizer.step()

        self.assertGradientDetected(
            "weight", prof, _EventType.PyCall, model[0].weight.grad, model[0].weight
        )


if __name__ == "__main__":
    run_tests()
torch/_C/_profiler.pyi
@@ -3,6 +3,8 @@ from typing import List, Optional, Tuple, Union

 from torch._C import device, dtype, layout

+from typing_extensions import Literal
+
 # defined in torch/csrc/profiler/python/init.cpp

 class RecordScope(Enum):
@@ -38,11 +40,12 @@ class ProfilerActivity(Enum):
     CUDA = ...

 class _EventType(Enum):
-    Allocation = ...
+    TorchOp = ...
     Backend = ...
+    Allocation = ...
     OutOfMemory = ...
     PyCall = ...
     PyCCall = ...
-    TorchOp = ...
     Kineto = ...
+
 class _ExperimentalConfig:
@@ -71,6 +74,8 @@ class _ProfilerEvent:
     start_tid: int
     start_time_ns: int
     children: List[_ProfilerEvent]
+
+    # TODO(robieta): remove in favor of `self.typed`
     extra_fields: Union[
         _ExtraFields_TorchOp,
         _ExtraFields_Backend,
@@ -81,6 +86,18 @@ class _ProfilerEvent:
         _ExtraFields_Kineto,
     ]
+
+    @property
+    def typed(
+        self,
+    ) -> Union[
+        Tuple[Literal[_EventType.TorchOp], _ExtraFields_TorchOp],
+        Tuple[Literal[_EventType.Backend], _ExtraFields_Backend],
+        Tuple[Literal[_EventType.Allocation], _ExtraFields_Allocation],
+        Tuple[Literal[_EventType.OutOfMemory], _ExtraFields_OutOfMemory],
+        Tuple[Literal[_EventType.PyCall], _ExtraFields_PyCall],
+        Tuple[Literal[_EventType.PyCCall], _ExtraFields_PyCCall],
+        Tuple[Literal[_EventType.Kineto], _ExtraFields_Kineto],
+    ]: ...
     @property
     def name(self) -> str: ...
     @property
@@ -101,6 +118,8 @@ class _TensorMetadata:
     storage_data_ptr: Optional[int]
     id: Optional[int]

+    @property
+    def allocation_id(self) -> Optional[int]: ...
     @property
     def layout(self) -> layout: ...
     @property
@@ -129,11 +148,12 @@ class _ExtraFields_Backend: ...
 class _ExtraFields_Allocation:
     ptr: int
     id: Optional[int]
-    allocation_id: Optional[int]
     alloc_size: int
     total_allocated: int
     total_reserved: int

     @property
+    def allocation_id(self) -> Optional[int]: ...
+    @property
     def device(self) -> device: ...
@@ -147,22 +167,47 @@ class _PyFrameState:
     def file_name(self) -> str: ...

 class _NNModuleInfo:
-    @property
-    def params(self) -> List[Tuple[str, int]]: ...
+    @property
+    def self_ptr(self) -> int: ...
+    @property
+    def cls_ptr(self) -> int: ...
+    @property
+    def cls_name(self) -> str: ...
+    @property
+    def parameters(
+        self,
+    ) -> List[Tuple[str, _TensorMetadata, Optional[_TensorMetadata]]]: ...
+
+class _OptimizerInfo:
+    @property
+    def parameters(
+        self,
+    ) -> List[
+        Tuple[
+            # Parameter
+            _TensorMetadata,
+            #
+            # Gradient (if present during optimizer.step())
+            Optional[_TensorMetadata],
+            #
+            # Optimizer state for Parameter as (name, tensor) pairs
+            List[Tuple[str, _TensorMetadata]],
+        ]
+    ]: ...

 class _ExtraFields_PyCCall:
-    callsite: _PyFrameState
-    caller: _PyFrameState
-    module: Optional[_NNModuleInfo]
+    @property
+    def caller(self) -> _PyFrameState: ...

 class _ExtraFields_PyCall:
-    caller: _PyFrameState
+    @property
+    def callsite(self) -> _PyFrameState: ...
+    @property
+    def caller(self) -> _PyFrameState: ...
+    @property
+    def module(self) -> Optional[_NNModuleInfo]: ...
+    @property
+    def optimizer(self) -> Optional[_OptimizerInfo]: ...

 class _ExtraFields_Kineto: ...
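The `typed` property pairs the event's tag with its payload, so comparing `typed[0]` against an `_EventType` member lets a type checker narrow `typed[1]` to the matching `_ExtraFields_*` class. A small sketch of the pattern, mirroring how `extract_gradients` (below) consumes it:

    from torch._C._profiler import _EventType, _ProfilerEvent, RecordScope

    def is_backward_function(node: _ProfilerEvent) -> bool:
        # After the tag check, `node.typed[1]` is known to be
        # `_ExtraFields_TorchOp`, so `.scope` type-checks.
        return (
            node.typed[0] == _EventType.TorchOp
            and node.typed[1].scope == RecordScope.BACKWARD_FUNCTION
        )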
torch/csrc/profiler/python/init.cpp
@@ -251,6 +251,13 @@ void initPythonBindings(PyObject* module) {
           .def_property_readonly("name", &Result::name)
           .def_property_readonly("tag", &Result::tag)
           .def_readonly("extra_fields", &Result::extra_fields_)
+          .def_property_readonly(
+              "typed",
+              [](const Result& r) {
+                return py::make_tuple(
+                    r.tag(),
+                    py::cast(r.extra_fields_, py::return_value_policy::reference));
+              })
           .def_property_readonly(
               "id",
               [](const Result& r) {
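At runtime the binding above simply packs the existing `tag` property together with `extra_fields` (cast by reference rather than copied). A hedged illustration, assuming `node` is any `_ProfilerEvent` from the event tree:

    tag, fields = node.typed
    assert tag == node.tag  # same value as the pre-existing `tag` property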
torch/profiler/_memory_profiler.py (new file)
@@ -0,0 +1,114 @@
import dataclasses
from typing import Any, Iterator, Optional, Tuple

import torch
from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata, RecordScope


@dataclasses.dataclass
class _Storage:
    """Bundle storage pointer and id.

    All profiling logic should use `allocation_id`, however it is useful to
    print storage pointers for debugging and unit tests sometimes look up
    values using the storage data pointer of a live Tensor."""

    ptr: int
    allocation_id: int

    def __repr__(self) -> str:
        return f"{hex(self.ptr):>18} ({self.allocation_id})"

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, _Storage) and self.allocation_id == other.allocation_id

    def __hash__(self) -> int:
        return hash(self.allocation_id)


@dataclasses.dataclass(eq=True, unsafe_hash=True, frozen=True)
class TensorKey:
    """Hashable identifier for a storage which has been assigned an ID.

    A detailed description of Tensor IDs and why they are needed is given in
    `torch/csrc/profiler/collection.h` when `TensorID` is declared. To
    summarize, multiple Storage buffers can map to the same logical Tensor.
    This dataclass is used to refer to a concrete in-memory StorageImpl of
    a Tensor.
    """

    id: int
    storage: _Storage
    device: torch.device

    def __repr__(self) -> str:
        return f"id={self.id}: {repr(self.storage):<24} ({self.device})"

    @staticmethod
    def _make(
        tensor_id: Optional[int],
        storage_ptr: Optional[int],
        allocation_id: Optional[int],
        device: torch.device,
    ) -> Optional["TensorKey"]:
        if (
            tensor_id is not None
            and storage_ptr is not None
            and allocation_id is not None
        ):
            return TensorKey(tensor_id, _Storage(storage_ptr, allocation_id), device)
        return None

    @classmethod
    def from_tensor(cls, t: Optional[_TensorMetadata]) -> Optional["TensorKey"]:
        if t is not None:
            return cls._make(t.id, t.storage_data_ptr, t.allocation_id, t.device)
        return None


def extract_gradients(
    node: _ProfilerEvent,
) -> Iterator[Tuple[Optional[TensorKey], TensorKey]]:
    children = node.children

    # AccumulateGrad is used in the Autograd engine to handle gradient updates.
    # There are two possible cases:
    #  1) This is a newly created gradient Tensor. In that case there is nothing
    #     to accumulate, so autograd simply detaches the Tensor.
    #
    #  2) There is a preexisting gradient Tensor and we need to add the newly
    #     computed update. This is done with an in-place add (aten::add_) op.
    #     (The underscore suffix denotes "in-place".)
    if (
        node.typed[0] == _EventType.TorchOp
        and node.typed[1].scope == RecordScope.BACKWARD_FUNCTION
        # TODO(robieta): Move away from load bearing names
        and node.name == "torch::autograd::AccumulateGrad"
        and children
        and children[0].typed[0] == _EventType.TorchOp
        and children[0].name in ("aten::detach", "aten::add_")
        and children[0].typed[1].inputs
        and isinstance(children[0].typed[1].inputs[0], _TensorMetadata)
    ):
        key = TensorKey.from_tensor(children[0].typed[1].inputs[0])
        if key:
            yield None, key

    # We directly instrument `torch.nn.Module` and `torch.optim.Optimizer`
    # NOTE: The values captured by the python tracer are cached; they can be
    #       used to build up labels but do not imply that a Tensor was live at
    #       a particular time.
    elif node.typed[0] == _EventType.PyCall:
        typed_fields = node.typed[1]
        assert typed_fields.module is None or typed_fields.optimizer is None
        if typed_fields.module is not None:
            for _, p, p_grad in typed_fields.module.parameters:
                p_grad_key = TensorKey.from_tensor(p_grad)
                if p_grad_key is not None:
                    yield TensorKey.from_tensor(p), p_grad_key

        if typed_fields.optimizer is not None:
            for p, p_grad, _ in typed_fields.optimizer.parameters:
                p_grad_key = TensorKey.from_tensor(p_grad)
                if p_grad_key is not None:
                    yield TensorKey.from_tensor(p), p_grad_key
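To make the `_Storage`/`TensorKey` semantics above concrete: keys compare and hash by `allocation_id`, not by raw pointer, so an address the allocator reuses does not alias two distinct allocations. An illustrative sketch (the pointer values are made up):

    import torch
    from torch.profiler._memory_profiler import TensorKey, _Storage

    cpu = torch.device("cpu")

    # Same address, different allocations: distinct keys.
    first = TensorKey(id=0, storage=_Storage(ptr=0x7F00, allocation_id=1), device=cpu)
    reused = TensorKey(id=0, storage=_Storage(ptr=0x7F00, allocation_id=2), device=cpu)
    assert first != reused

    # Same allocation id matches even if the recorded pointer differs.
    assert first == TensorKey(id=0, storage=_Storage(ptr=0x8E00, allocation_id=1), device=cpu)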