[Profiler] Memory profiler part 1: Gradient identification (#86802)

There are multiple ways to identify that a Tensor is a gradient, a subset of which also provide additional context. So to start off I've made a utility to handle that determination. (A usage sketch follows the commit metadata below.)

Differential Revision: [D39920730](https://our.internmc.facebook.com/intern/diff/D39920730/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86802
Approved by: https://github.com/chaekit
Author: Taylor Robie
Date: 2022-11-07 15:48:35 -08:00
Committed by: PyTorch MergeBot
Parent: c0e6b4329f
Commit: cef13ebea0
5 changed files with 400 additions and 9 deletions
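
As a rough usage sketch (distilled from the new unit tests below; not itself part of the diff), the utility is driven by walking the profiler's event tree and asking `extract_gradients` about each node:

import torch
from torch.profiler import _memory_profiler, _utils

# Sketch only: `extract_gradients` yields (parameter_key, gradient_key)
# pairs, with None for the parameter when op-based detection cannot link
# a gradient back to its source parameter.
w = torch.ones((1,), requires_grad=True)
with torch.profiler.profile(
    record_shapes=True, profile_memory=True, with_stack=True
) as prof:
    (torch.ones((4,)) * w).sum().backward()

tree = prof.profiler.kineto_results.experimental_event_tree()
for node in _utils.traverse_dfs(tree):
    for param_key, grad_key in _memory_profiler.extract_gradients(node):
        print(node.name, param_key, grad_key)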

mypy-strict.ini

@@ -40,6 +40,7 @@ files =
.github,
benchmarks/instruction_counts,
tools,
torch/profiler/_memory_profiler.py,
torch/utils/_pytree.py,
torch/utils/benchmark/utils/common.py,
torch/utils/benchmark/utils/timer.py,

test/profiler/test_memory_profiler.py

@@ -0,0 +1,224 @@
# Owner(s): ["oncall: profiler"]
import functools
from typing import Iterator, Optional
import torch
from torch._C._profiler import _EventType
from torch.profiler import _memory_profiler, _utils
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
profile = functools.partial(
torch.profiler.profile, record_shapes=True, profile_memory=True, with_stack=True
)
class ScaleLayer(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.scale = torch.nn.Parameter(torch.rand(()), requires_grad=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x * self.scale
@skipIfTorchDynamo("TorchDynamo changes Python calls that memory profiling relies on.")
class TestIdentifyGradients(TestCase):
def gradient_detected(
self,
prof: torch.profiler.profile,
ctx: _EventType,
grad_tensor: torch.Tensor,
parameter: Optional[torch.Tensor] = None,
) -> None:
# This is not an exhaustive check, but for the purpose of unit testing
# it is sufficient.
def key_matches_tensor(key, tensor) -> bool:
# Vacuous case.
if tensor is None:
return True
if key is None:
return False
return tensor.storage().data_ptr() == key.storage.ptr
tree = prof.profiler.kineto_results.experimental_event_tree()
for node in _utils.traverse_dfs(tree):
for p_key, p_grad_key in _memory_profiler.extract_gradients(node):
if node.tag == ctx and key_matches_tensor(p_grad_key, grad_tensor):
if parameter is None:
return True # Don't need to check parameter; we're done.
elif p_key is not None:
# For a complex workflow a gradient could correspond to
# different parameters at different points in a trace.
# However this will not happen in the relatively simple
# cases tested here, so if `extract_gradients` identifies
# the parameter corresponding to a particular gradient it
# must be the one we expect.
self.assertTrue(key_matches_tensor(p_key, parameter))
return True
return False
def assertGradientDetected(self, name: str, *args, **kwargs) -> None:
self.assertTrue(
self.gradient_detected(*args, **kwargs),
f"Failed to identify gradient `{name}` from profile.",
)
def assertOnlyGradients(
self, prof: torch.profiler.profile, tensors: Iterator[torch.Tensor]
) -> None:
allowed_set = {t.storage().data_ptr() for t in tensors}
tree = prof.profiler.kineto_results.experimental_event_tree()
for node in _utils.traverse_dfs(tree):
for _, p_grad_key in _memory_profiler.extract_gradients(node):
self.assertTrue(
p_grad_key.storage.ptr in allowed_set,
f"Tensor wrongly marked as gradient: {node.name}: {p_grad_key}",
)
def test_extract_gradients_low_level(self) -> None:
x = torch.ones((1,))
w0 = torch.ones((1,), requires_grad=True)
w1 = torch.ones((1,), requires_grad=True)
def check(cold_start: bool):
self.assertEqual(w0.grad is None, cold_start)
self.assertEqual(w1.grad is None, cold_start)
with profile() as prof:
z = x.expand(4) * w0
(z * w1).sum().backward()
# Gradient detection through op inspection does not provide a
# reference to the parameter corresponding to the gradient.
self.assertGradientDetected("w0", prof, _EventType.TorchOp, w0.grad)
self.assertGradientDetected("w1", prof, _EventType.TorchOp, w1.grad)
self.assertOnlyGradients(prof, (w0.grad, w1.grad))
check(cold_start=True)
check(cold_start=False)
def test_extract_gradients_from_module(self) -> None:
model = torch.nn.Sequential(torch.nn.Linear(2, 1), ScaleLayer())
named_parameters = {name: p for name, p in model.named_parameters()}
self.assertEqual(len(named_parameters), 3)
def assert_only_gradients(prof: torch.profiler.profile):
gradients = tuple(i.grad for i in named_parameters.values())
self.assertFalse(any(i is None for i in gradients))
self.assertOnlyGradients(prof, gradients)
def check(cold_start: bool):
x = torch.ones((2, 2))
with profile() as prof:
model(x).sum().backward()
for name, p in named_parameters.items():
# The first time we run a module none of the `.grad` fields
# have been initialized. This is fine; in that case we can
# detect everything we need in the profiled section.
self.assertNotEqual(
self.gradient_detected(prof, _EventType.PyCall, p.grad, p),
cold_start,
name,
)
# Op based detection should still identify the gradients.
self.assertGradientDetected(name, prof, _EventType.TorchOp, p.grad)
assert_only_gradients(prof)
# We can detect gradients even when `.backward()` is not called.
with profile() as prof:
model(torch.ones((2, 2)))
for name, p in named_parameters.items():
self.assertGradientDetected(name, prof, _EventType.PyCall, p.grad, p)
self.assertFalse(
self.gradient_detected(prof, _EventType.TorchOp, p.grad), name
)
assert_only_gradients(prof)
check(cold_start=True)
check(cold_start=False)
def _test_extract_gradients_from_optimizer(self, set_to_none: bool) -> None:
x = torch.ones((1,))
w0 = torch.ones((1,), requires_grad=True)
w1 = torch.ones((1,), requires_grad=True)
optimizer = torch.optim.SGD((w0, w1), lr=0.1, momentum=0.9)
def check(cold_start: bool):
self.assertEqual(w0.grad is None, cold_start)
self.assertEqual(w1.grad is None, cold_start)
with profile() as prof:
optimizer.zero_grad(set_to_none=set_to_none)
z = x.expand(4) * w0
(z * w1).sum().backward()
optimizer.step()
# Optimizer instrumentation runs late in the step, so we can detect
# gradients for both cold and warm start.
self.assertGradientDetected("w0", prof, _EventType.PyCall, w0.grad, w0)
self.assertGradientDetected("w1", prof, _EventType.PyCall, w1.grad, w1)
self.assertGradientDetected("w0", prof, _EventType.TorchOp, w0.grad)
self.assertGradientDetected("w1", prof, _EventType.TorchOp, w1.grad)
self.assertOnlyGradients(prof, (w0.grad, w1.grad))
with profile() as prof:
for _ in range(2):
optimizer.zero_grad(set_to_none=set_to_none)
z = x.expand(4) * w0
(z * w1).sum().backward()
optimizer.step()
# Inspected state is cached, so if we replace gradients (as is the
# case for `set_to_none=True`) our python instrumentation will not
# see them.
# TODO(robieta): Should `.step()` be excluded from caching?
self.assertNotEqual(
self.gradient_detected(prof, _EventType.PyCall, w0.grad, w0),
set_to_none,
)
self.assertNotEqual(
self.gradient_detected(prof, _EventType.PyCall, w1.grad, w1),
set_to_none,
)
if set_to_none:
with self.assertRaisesRegex(AssertionError, "Tensor wrongly marked"):
self.assertOnlyGradients(prof, (w0.grad, w1.grad))
check(cold_start=True)
check(cold_start=False)
def test_extract_gradients_from_optimizer(self) -> None:
self._test_extract_gradients_from_optimizer(set_to_none=False)
def test_extract_gradients_from_optimizer_set_to_none(self) -> None:
self._test_extract_gradients_from_optimizer(set_to_none=True)
def test_extract_gradients_from_module_and_optimizer(self) -> None:
# Module and optimizer are thoroughly tested individually and should be
# additive. Thus we can manage with a lightweight check that they don't
# interact adversely.
model = torch.nn.Sequential(torch.nn.Linear(2, 1), ScaleLayer())
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
with profile() as prof:
model(torch.ones((2, 2))).sum().backward()
optimizer.step()
self.assertGradientDetected(
"weight", prof, _EventType.PyCall, model[0].weight.grad, model[0].weight
)
if __name__ == "__main__":
run_tests()

torch/_C/_profiler.pyi

@@ -3,6 +3,8 @@ from typing import List, Optional, Tuple, Union
from torch._C import device, dtype, layout
from typing_extensions import Literal
# defined in torch/csrc/profiler/python/init.cpp
class RecordScope(Enum):
@@ -38,11 +40,12 @@ class ProfilerActivity(Enum):
CUDA = ...
class _EventType(Enum):
Allocation = ...
TorchOp = ...
Backend = ...
Allocation = ...
OutOfMemory = ...
PyCall = ...
PyCCall = ...
TorchOp = ...
Kineto = ...
class _ExperimentalConfig:
@@ -71,6 +74,8 @@ class _ProfilerEvent:
start_tid: int
start_time_ns: int
children: List[_ProfilerEvent]
# TODO(robieta): remove in favor of `self.typed`
extra_fields: Union[
_ExtraFields_TorchOp,
_ExtraFields_Backend,
@@ -81,6 +86,18 @@ class _ProfilerEvent:
_ExtraFields_Kineto,
]
@property
def typed(
self,
) -> Union[
Tuple[Literal[_EventType.TorchOp], _ExtraFields_TorchOp],
Tuple[Literal[_EventType.Backend], _ExtraFields_Backend],
Tuple[Literal[_EventType.Allocation], _ExtraFields_Allocation],
Tuple[Literal[_EventType.OutOfMemory], _ExtraFields_OutOfMemory],
Tuple[Literal[_EventType.PyCall], _ExtraFields_PyCall],
Tuple[Literal[_EventType.PyCCall], _ExtraFields_PyCCall],
Tuple[Literal[_EventType.Kineto], _ExtraFields_Kineto],
]: ...
@property
def name(self) -> str: ...
@property
@@ -101,6 +118,8 @@ class _TensorMetadata:
storage_data_ptr: Optional[int]
id: Optional[int]
@property
def allocation_id(self) -> Optional[int]: ...
@property
def layout(self) -> layout: ...
@property
@@ -129,11 +148,12 @@ class _ExtraFields_Backend: ...
class _ExtraFields_Allocation:
ptr: int
id: Optional[int]
allocation_id: Optional[int]
alloc_size: int
total_allocated: int
total_reserved: int
@property
def allocation_id(self) -> Optional[int]: ...
@property
def device(self) -> device: ...
@@ -147,22 +167,47 @@ class _PyFrameState:
def file_name(self) -> str: ...
class _NNModuleInfo:
@property
def params(self) -> List[Tuple[str, int]]: ...
@property
def self_ptr(self) -> int: ...
@property
def cls_ptr(self) -> int: ...
@property
def cls_name(self) -> str: ...
@property
def parameters(
self,
) -> List[Tuple[str, _TensorMetadata, Optional[_TensorMetadata]]]: ...
class _OptimizerInfo:
@property
def parameters(
self,
) -> List[
Tuple[
# Parameter
_TensorMetadata,
#
# Gradient (if present during optimizer.step())
Optional[_TensorMetadata],
#
# Optimizer state for Parameter as (name, tensor) pairs
List[Tuple[str, _TensorMetadata]],
]
]: ...
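# For illustration (hypothetical values): with SGD(momentum=0.9), each
# entry resembles
#   (param_meta, grad_meta_or_None, [("momentum_buffer", buffer_meta)])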
class _ExtraFields_PyCCall:
callsite: _PyFrameState
caller: _PyFrameState
module: Optional[_NNModuleInfo]
@property
def caller(self) -> _PyFrameState: ...
class _ExtraFields_PyCall:
caller: _PyFrameState
@property
def callsite(self) -> _PyFrameState: ...
@property
def caller(self) -> _PyFrameState: ...
@property
def module(self) -> Optional[_NNModuleInfo]: ...
@property
def optimizer(self) -> Optional[_OptimizerInfo]: ...
class _ExtraFields_Kineto: ...
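
A minimal narrowing sketch (not part of the diff; `is_backward_op` is a hypothetical helper) showing how the tagged-union `typed` property lets a type checker refine `extra_fields`, mirroring its use in torch/profiler/_memory_profiler.py below:

from torch._C._profiler import _EventType, _ProfilerEvent, RecordScope

def is_backward_op(node: _ProfilerEvent) -> bool:
    # Checking the Literal tag in typed[0] selects the matching Tuple
    # member of the union, so typed[1] is known to be _ExtraFields_TorchOp
    # and attribute access such as `.scope` type-checks.
    return (
        node.typed[0] == _EventType.TorchOp
        and node.typed[1].scope == RecordScope.BACKWARD_FUNCTION
    )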

torch/csrc/profiler/python/init.cpp

@@ -251,6 +251,13 @@ void initPythonBindings(PyObject* module) {
.def_property_readonly("name", &Result::name)
.def_property_readonly("tag", &Result::tag)
.def_readonly("extra_fields", &Result::extra_fields_)
.def_property_readonly(
"typed",
[](const Result& r) {
return py::make_tuple(
r.tag(),
py::cast(r.extra_fields_, py::return_value_policy::reference));
})
.def_property_readonly(
"id",
[](const Result& r) {

torch/profiler/_memory_profiler.py

@@ -0,0 +1,114 @@
import dataclasses
from typing import Any, Iterator, Optional, Tuple
import torch
from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata, RecordScope
@dataclasses.dataclass
class _Storage:
"""Bundle storage pointer and id.
All profiling logic should use `allocation_id`, however it is useful to
print storage pointers for debugging and unit tests sometimes look up
values using the storage data pointer of a live Tensor."""
ptr: int
allocation_id: int
def __repr__(self) -> str:
return f"{hex(self.ptr):>18} ({self.allocation_id})"
def __eq__(self, other: Any) -> bool:
return isinstance(other, _Storage) and self.allocation_id == other.allocation_id
def __hash__(self) -> int:
return hash(self.allocation_id)
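# For illustration (not committed code): equality and hashing key on
# `allocation_id`, so a recycled storage address never aliases two
# distinct allocations:
#   _Storage(ptr=0xDEADBEEF, allocation_id=1)  # != the line below
#   _Storage(ptr=0xDEADBEEF, allocation_id=2)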
@dataclasses.dataclass(eq=True, unsafe_hash=True, frozen=True)
class TensorKey:
"""Hashable identifier for a storage which has been asigned an ID.
A detailed description of Tensor IDs and why they are needed is given in
`torch/csrc/profiler/collection.h` when `TensorID` is declared. To
summarize, multiple Storage buffers can map to the same logical Tensor.
This dataclass is used to refer to a concrete in-memory StorageImpl of
a Tensor.
"""
id: int
storage: _Storage
device: torch.device
def __repr__(self) -> str:
return f"id={self.id}: {repr(self.storage):<24} ({self.device})"
@staticmethod
def _make(
tensor_id: Optional[int],
storage_ptr: Optional[int],
allocation_id: Optional[int],
device: torch.device,
) -> Optional["TensorKey"]:
if (
tensor_id is not None
and storage_ptr is not None
and allocation_id is not None
):
return TensorKey(tensor_id, _Storage(storage_ptr, allocation_id), device)
return None
@classmethod
def from_tensor(cls, t: Optional[_TensorMetadata]) -> Optional["TensorKey"]:
if t is not None:
return cls._make(t.id, t.storage_data_ptr, t.allocation_id, t.device)
return None
def extract_gradients(
node: _ProfilerEvent,
) -> Iterator[Tuple[Optional[TensorKey], TensorKey]]:
children = node.children
# AccumulateGrad is used in the Autograd engine to handle gradient updates.
# There are two possible cases:
# 1) This is a newly created gradient Tensor. In that case there is nothing
# to accumulate, so autograd simply detaches the Tensor.
#
# 2) There is a preexisting gradient Tensor and we need to add the newly
# computed update. This is done with an in-place add (aten::add_) op.
# (The underscore suffix denotes "in-place".)
if (
node.typed[0] == _EventType.TorchOp
and node.typed[1].scope == RecordScope.BACKWARD_FUNCTION
# TODO(robieta): Move away from load bearing names
and node.name == "torch::autograd::AccumulateGrad"
and children
and children[0].typed[0] == _EventType.TorchOp
and children[0].name in ("aten::detach", "aten::add_")
and children[0].typed[1].inputs
and isinstance(children[0].typed[1].inputs[0], _TensorMetadata)
):
key = TensorKey.from_tensor(children[0].typed[1].inputs[0])
if key:
yield None, key
# We directly instrument `torch.nn.Module` and `torch.optim.Optimizer`
# NOTE: The values captured by the python tracer are cached; they can be
# used to build up labels but do not imply that a Tensor was live at
# a particular time.
elif node.typed[0] == _EventType.PyCall:
typed_fields = node.typed[1]
assert typed_fields.module is None or typed_fields.optimizer is None
if typed_fields.module is not None:
for _, p, p_grad in typed_fields.module.parameters:
p_grad_key = TensorKey.from_tensor(p_grad)
if p_grad_key is not None:
yield TensorKey.from_tensor(p), p_grad_key
if typed_fields.optimizer is not None:
for p, p_grad, _ in typed_fields.optimizer.parameters:
p_grad_key = TensorKey.from_tensor(p_grad)
if p_grad_key is not None:
yield TensorKey.from_tensor(p), p_grad_key
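
To illustrate the two AccumulateGrad cases handled above (a sketch, not part of the diff; the op names follow the comment in `extract_gradients`):

import torch

w = torch.ones((1,), requires_grad=True)
w.sum().backward()  # No preexisting gradient: autograd detaches (aten::detach).
w.sum().backward()  # w.grad now exists: in-place accumulation (aten::add_).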