Files
pytorch/test/profiler/test_torch_tidy.py
2025-07-25 02:56:34 +00:00

912 lines
33 KiB
Python

# Owner(s): ["oncall: profiler"]
import gc
import re
import textwrap
import unittest
import weakref
from typing import Any
import torch
import torch.nn as nn
import torch.optim
import torch.utils.data
from torch._C._profiler import _ExtraFields_PyCall, _TensorMetadata
from torch.profiler import _utils, profile
from torch.testing._internal.common_utils import run_tests, TestCase
# If tqdm is not shut down properly, it leaves its monitor thread alive.
# This causes an issue in the multithreading test because we check all events
# in that test with their tids. The events that correspond to these lingering
# threads all have a TID of (uint64_t)(-1), which is invalid.
# The workaround is turning off the monitoring thread when tqdm is loaded.
# Since these are unit tests, it is safe to turn off the monitor thread.
try:
    import tqdm

    tqdm.tqdm.monitor_interval = 0
except ImportError:
    # tqdm is optional; if it cannot be imported there is nothing to configure.
    pass

# Type alias: a dictionary with string keys and arbitrary values.
Json = dict[str, Any]
def find_node_with_name(nodes, name):
    """Return the first node in a DFS walk of `nodes` whose name equals `name`.

    Returns `None` when no visited node matches.
    """
    matches = (
        candidate
        for candidate in _utils.traverse_dfs(nodes)
        if candidate.name == name
    )
    return next(matches, None)
def find_node_with_regex(nodes, pattern):
    """Return the first node in a DFS walk of `nodes` whose name matches `pattern`.

    Matching uses `re.search`, so the pattern may match anywhere in the name.
    Returns `None` when no visited node matches.
    """
    return next(
        (
            candidate
            for candidate in _utils.traverse_dfs(nodes)
            if re.search(pattern, candidate.name)
        ),
        None,
    )
class SimpleNet(nn.Module):
    """Two stacked linear layers (10 -> 5 -> 2) used as a small profiling target."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        """Apply `fc1` to `x`, then `fc2` to the result."""
        hidden = self.fc1(x)
        return self.fc2(hidden)
class TestTorchTidyProfiler(TestCase):
def _get_tensor_fields(self, node, index):
self.assertIsNotNone(node)
self.assertIsInstance(
node.extra_fields, torch._C._profiler._ExtraFields_TorchOp
)
tensor_info = node.extra_fields.inputs[index]
self.assertIsInstance(tensor_info, _TensorMetadata)
self.assertIsNotNone(tensor_info.impl_ptr)
self.assertIsNotNone(tensor_info.storage_data_ptr)
self.assertIsNotNone(tensor_info.id)
return tensor_info.impl_ptr, tensor_info.storage_data_ptr, tensor_info.id
def test_pointers_and_ids(self):
    # Checks the impl pointer, storage pointer, and ID the profiler reports
    # for op inputs across views, `resize_`, and `set_`.
    a = torch.randn(4, 3)
    a_initial_storage_data = a.storage().data_ptr()

    # Views of tensors can share the same storage, but have different TensorImpls
    b = a.view((1, 12))
    c = torch.randn(4, 1)
    c_initial_storage_data = c.storage().data_ptr()
    d = torch.randn(4, 3)

    with profile(with_stack=True, profile_memory=True, record_shapes=True) as p:
        _ = a + c
        _ = b * c

        # Resize should create a new data_ptr but keep the TensorImpl the same.
        f = a.resize_(128, 129)
        _ = torch.relu(f)

        # `.set_` points a Tensor at an existing storage.
        _ = d.sin()
        c.set_(d.storage())
        _ = c.cos()

    event_tree = p.profiler.kineto_results.experimental_event_tree()

    def input_fields(op_name, index):
        op_node = find_node_with_name(event_tree, op_name)
        return self._get_tensor_fields(op_node, index)

    a_impl, a_storage_data, a_id = input_fields("aten::add", 0)
    mul_lhs = input_fields("aten::mul", 0)
    b_impl, b_storage_data = mul_lhs[0], mul_lhs[1]

    # Profiler matches ground truth from Python API.
    self.assertEqual(a_storage_data, a_initial_storage_data)

    # Views are handled correctly.
    self.assertEqual(a_storage_data, b_storage_data)
    self.assertNotEqual(a_impl, b_impl)

    # The same Tensor used in multiple calls gives identical results.
    add_rhs = input_fields("aten::add", 1)
    self.assertEqual(add_rhs, input_fields("aten::mul", 1))
    _, c_storage_data, c_id = add_rhs
    self.assertEqual(c_storage_data, c_initial_storage_data)

    # Mutations to the underlying storage are reflected. (But ID is shared.)
    f_impl, f_storage_data, f_id = input_fields("aten::relu", 0)
    self.assertEqual(a_impl, f_impl)
    self.assertNotEqual(a_storage_data, f_storage_data)
    self.assertEqual(a_id, f_id)

    # Calling `set_` with an existing Tensor makes them share an ID.
    d_impl, d_storage_data, d_id = input_fields("aten::sin", 0)
    c_impl_new, c_storage_data_new, c_id_new = input_fields("aten::cos", 0)
    self.assertNotEqual(d_impl, c_impl_new)
    self.assertEqual(d_storage_data, c_storage_data_new)
    self.assertEqual(c_id, c_id_new)
    self.assertEqual(d_id, c_id_new)
@staticmethod
def _format_allocations(profiled_code):
gc.collect()
with profile(profile_memory=True, record_shapes=True) as prof:
profiled_code()
gc.collect()
root_events = prof.profiler.kineto_results.experimental_event_tree()
events = sorted(_utils.traverse_dfs(root_events), key=lambda x: x.start_time_ns)
allocations = tuple(
event.extra_fields
for event in events
if isinstance(
event.extra_fields, torch._C._profiler._ExtraFields_Allocation
)
)
return textwrap.indent(
"\n".join(
f"{repr(i.id):>5}{' ' * 6}"
f"{repr(i.allocation_id):>5}{' ' * 6}"
f"{'Allocation' if i.alloc_size > 0 else 'Free'}"
for i in allocations
),
" " * 12,
)
def test_tensorimpl_invalidation_set(self) -> None:
    # Compares the allocation trace produced when a tensor's storage is
    # replaced via `set_`, with and without an empty `set_()` first.
    # Each expected row reads: id, allocation_id, event kind.
    # NOTE(review): `_format_allocations` pads its rows (12-space indent,
    # width-5 columns, 6-space gaps) while the literals below show single
    # spaces and no indent -- confirm the expected text was not
    # whitespace-collapsed.
    def profiled_code(add_empty_set: bool):
        x = torch.ones((1,))

        # Determines if new storage is created before or after the old one
        # is destroyed.
        if add_empty_set:
            x.set_()

        x.set_(torch.ones((1,)).storage())
        x.view_as(x)

    # Without the empty `set_()`: the new storage is allocated before the
    # old one is freed.
    self.assertExpectedInline(
        self._format_allocations(lambda: profiled_code(add_empty_set=False)),
        """\
0 1 Allocation
0 2 Allocation
0 1 Free
0 2 Free""",
    )

    # With the empty `set_()`: the old storage is freed before the new one
    # is allocated.
    self.assertExpectedInline(
        self._format_allocations(lambda: profiled_code(add_empty_set=True)),
        """\
0 1 Allocation
0 1 Free
0 2 Allocation
0 2 Free""",
    )
def test_tensorimpl_invalidation_keep_alive(self) -> None:
    # Compares allocation traces when references to earlier storages of `x`
    # are held in a list, released one by one, and `x` is then re-pointed at
    # fresh storages. Each expected row reads: id, allocation_id, event kind.
    # NOTE(review): `_format_allocations` pads its rows (12-space indent,
    # width-5 columns, 6-space gaps) while the literals below show single
    # spaces and no indent -- confirm the expected text was not
    # whitespace-collapsed.
    def profiled_code(add_empty_set: bool):
        x = torch.ones((1,))
        x_storages = [x.storage()]
        for _ in range(3):
            x.set_()
            x.set_(torch.ones((1,)).storage())

            # This keeps the StorageImpls alive and preserves the chain.
            # (Despite the `set_()` call.)
            x_storages.append(x.storage())
        x.view_as(x)

        # Free storage in a deterministic fashion.
        while x_storages:
            x_storages.pop()
            gc.collect()

        # Determines if new storage is created before or after the old one
        # is destroyed.
        if add_empty_set:
            x.set_()

        for _ in range(3):
            x.set_(torch.ones((1,)).storage())
        x.view_as(x)

        # Drop the last reference to `x` inside the profiled region.
        del x
        gc.collect()

    self.assertExpectedInline(
        self._format_allocations(lambda: profiled_code(add_empty_set=False)),
        """\
0 1 Allocation
0 2 Allocation
0 4 Allocation
0 5 Allocation
0 4 Free
0 2 Free
0 1 Free
0 6 Allocation
0 5 Free
0 7 Allocation
0 6 Free
0 8 Allocation
0 7 Free
0 8 Free""",
    )

    # With the empty `set_()`, allocation 5 is freed before allocation 6
    # appears.
    self.assertExpectedInline(
        self._format_allocations(lambda: profiled_code(add_empty_set=True)),
        """\
0 1 Allocation
0 2 Allocation
0 4 Allocation
0 5 Allocation
0 4 Free
0 2 Free
0 1 Free
0 5 Free
0 6 Allocation
0 7 Allocation
0 6 Free
0 8 Allocation
0 7 Free
0 8 Free""",
    )
def test_tensorimpl_invalidation_full(self) -> None:
    # Runs one tensor through kept-alive storages, direct `set_` swaps,
    # empty-`set_()`-then-`set_` swaps, and `resize_` calls, and checks the
    # full allocation trace. Each expected row reads: id, allocation_id,
    # event kind.
    # NOTE(review): `_format_allocations` pads its rows (12-space indent,
    # width-5 columns, 6-space gaps) while the literal below shows single
    # spaces and no indent -- confirm the expected text was not
    # whitespace-collapsed.
    def profiled_code():
        x = torch.ones((1,))
        x_storages = [x.storage()]
        for _ in range(3):
            x.set_()
            x.set_(torch.ones((1,)).storage())
            x_storages.append(x.storage())
        x.view_as(x)

        # Free storage in a deterministic fashion.
        while x_storages:
            x_storages.pop()
            gc.collect()

        # Swap storages directly.
        for _ in range(3):
            x.set_(torch.ones((1,)).storage())

        # Swap storages with an empty `set_()` in between.
        for _ in range(3):
            x.set_()
            x.set_(torch.ones((1,)).storage())

        # Resize to 1, 2, 3, and 4 elements in turn.
        for i in range(4):
            x.resize_((1 + i,))
        x.view_as(x)

    self.assertExpectedInline(
        self._format_allocations(profiled_code),
        """\
0 1 Allocation
0 2 Allocation
0 4 Allocation
0 5 Allocation
0 4 Free
0 2 Free
0 1 Free
0 6 Allocation
0 5 Free
0 7 Allocation
0 6 Free
0 8 Allocation
0 7 Free
0 8 Free
0 9 Allocation
0 9 Free
0 10 Allocation
0 10 Free
0 11 Allocation
0 12 Allocation
0 11 Free
0 13 Allocation
0 12 Free
0 14 Allocation
0 13 Free
0 14 Free""",
    )
def test_tensorimpl_invalidation_scalar_args(self) -> None:
    # Checks the allocation trace of ten in-place `add_(2)` calls on one
    # tensor under `torch.no_grad()`. Each expected row reads: id,
    # allocation_id, event kind; the expected text shows two
    # allocation/free pairs per `add_` call, each with its own id.
    # NOTE(review): `_format_allocations` pads its rows (12-space indent,
    # width-5 columns, 6-space gaps) while the literal below shows single
    # spaces and no indent -- confirm the expected text was not
    # whitespace-collapsed.
    def profiled_code():
        with torch.no_grad():
            x = torch.ones((1,))
            for _ in range(10):
                x.add_(2)

    self.assertExpectedInline(
        self._format_allocations(profiled_code),
        """\
0 1 Allocation
1 2 Allocation
2 3 Allocation
2 3 Free
1 2 Free
3 4 Allocation
4 5 Allocation
4 5 Free
3 4 Free
5 6 Allocation
6 7 Allocation
6 7 Free
5 6 Free
7 8 Allocation
8 9 Allocation
8 9 Free
7 8 Free
9 10 Allocation
10 11 Allocation
10 11 Free
9 10 Free
11 12 Allocation
12 13 Allocation
12 13 Free
11 12 Free
13 14 Allocation
14 15 Allocation
14 15 Free
13 14 Free
15 16 Allocation
16 17 Allocation
16 17 Free
15 16 Free
17 18 Allocation
18 19 Allocation
18 19 Free
17 18 Free
19 20 Allocation
20 21 Allocation
20 21 Free
19 20 Free
0 1 Free""",
    )
def test_module_and_optimizer_ids(self) -> None:
    # Checks that tensor IDs recorded for module parameters and optimizer
    # state agree with IDs observed on ordinary op inputs.
    model = torch.nn.Linear(2, 1, bias=True)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    def check(cold_start: bool) -> None:
        with profile(with_stack=True, profile_memory=True, record_shapes=True) as p:
            x = torch.ones((1, 2))

            _ = x.sin()  # Mark `x`
            model(x).backward()
            optimizer.step()
            _ = optimizer.state[model.weight][
                "momentum_buffer"
            ].cos()  # Mark weight momentum
            _ = model.weight.grad.tan()  # Mark weight gradient

        event_tree = p.profiler.kineto_results.experimental_event_tree()

        def marked_id(op_name):
            op_node = find_node_with_name(event_tree, op_name)
            return self._get_tensor_fields(op_node, 0)[2]

        # Marked Tensors act as ground truth for python tracer IDs.
        x_id = marked_id("aten::sin")
        weight_momentum_id = marked_id("aten::cos")
        weight_grad_id = marked_id("aten::tan")
        self.assertNotEqual(x_id, weight_momentum_id)
        self.assertNotEqual(x_id, weight_grad_id)
        self.assertNotEqual(weight_momentum_id, weight_grad_id)

        # Use linear op to identify weight ground truth.
        linear_op_node = find_node_with_name(event_tree, "aten::linear")
        self.assertIsNotNone(linear_op_node)
        x_metadata, weight_metadata, _ = linear_op_node.extra_fields.inputs
        self.assertEqual(x_id, x_metadata.id)

        # Module
        linear_module_node = find_node_with_name(event_tree, "nn.Module: Linear_0")
        self.assertIsNotNone(linear_module_node)
        module_fields = linear_module_node.extra_fields
        self.assertIsNotNone(module_fields.module)
        self.assertIsNone(module_fields.optimizer)

        param_name, module_weight, module_weight_grad = (
            linear_module_node.extra_fields.module.parameters[0]
        )
        self.assertEqual(param_name, "weight")
        self.assertEqual(module_weight.id, weight_metadata.id)

        # The gradient is absent only on the first (cold) run.
        self.assertEqual(module_weight_grad is None, cold_start)
        if not cold_start:
            self.assertEqual(module_weight_grad.id, weight_grad_id)

        # Optimizer
        step_node = find_node_with_regex(event_tree, "_optimizer_step_code")
        self.assertIsNotNone(step_node)
        step_fields = step_node.extra_fields
        self.assertIsNone(step_fields.module)
        self.assertIsNotNone(step_fields.optimizer)

        optimizer_parameters = step_node.extra_fields.optimizer.parameters
        self.assertEqual(len(optimizer_parameters), 2)  # Weight and bias
        opt_weight, opt_weight_grad, state = optimizer_parameters[0]
        self.assertEqual(opt_weight.id, weight_metadata.id)
        self.assertEqual(opt_weight_grad.id, weight_grad_id)
        self.assertEqual(len(state), 1)
        state_name, state_tensor = state[0][0], state[0][1]
        self.assertEqual(state_name, "momentum_buffer")
        self.assertEqual(state_tensor.id, weight_momentum_id)

    # Check that we handle first step (lazy initialization) and steady state.
    check(cold_start=True)
    check(cold_start=False)
def _test_allocation_ids(self, before_fn, after_fn) -> None:
    """Check that a tensor keeps one ID across allocation, resize, and free.

    `before_fn` and `after_fn` are zero-argument callables invoked inside the
    profiled region before and after the tracked tensor is created and used.
    """
    with profile(profile_memory=True, record_shapes=True) as p:
        # Introduce other operations and allocations to check robustness
        _ = before_fn()

        x = torch.rand(4, 3)
        x.resize_(4, 4)

        # We need to use `x` post resize for profiler to determine its ID.
        x.sin()

        # Introduce other operations and allocations to check robustness
        _ = after_fn()

        # Ensure `x` is the last variable collected to make it easier to
        # find the deallocation event.
        gc.collect()
        del x
        gc.collect()

    nodes = p.profiler.kineto_results.experimental_event_tree()

    def find_chain(names: list[str]):
        # Each name is searched for beneath the previous match; the first
        # name is searched for from the tree roots.
        chain = []
        search_roots = nodes
        for name in names:
            match = find_node_with_name(search_roots, name)
            chain.append(match)
            self.assertIsNotNone(match, name)
            search_roots = [match]
        return chain

    memory_node = find_chain(["aten::rand", "aten::empty", "[memory]"])[-1]
    allocation = memory_node.extra_fields
    uniform_node = find_chain(["aten::rand", "aten::uniform_"])[1]
    _, x_storage_data, x_id = self._get_tensor_fields(uniform_node, 0)

    # Make sure IDs are consistent between allocations and op inputs
    self.assertEqual(allocation.ptr, x_storage_data)
    self.assertEqual(allocation.id, x_id)

    resize_node = find_node_with_name(nodes, "aten::resize_")
    self.assertIsNotNone(resize_node)
    self.assertEqual(len(resize_node.children), 2)
    first_child, second_child = resize_node.children[0], resize_node.children[1]
    allocate_new = first_child.extra_fields
    free_old = second_child.extra_fields

    # Destruction of the old storage for x.
    self.assertEqual(free_old.id, allocation.id)
    self.assertEqual(free_old.ptr, allocation.ptr)

    # Make sure ID is retained through change in storage.
    self.assertEqual(allocate_new.id, allocation.id)
    self.assertNotEqual(allocate_new.ptr, allocation.ptr)

    # Deletion when `x` goes out of scope.
    allocation_roots = [
        root for root in nodes if root.tag == torch._C._profiler._EventType.Allocation
    ]
    free_new = allocation_roots[-1].extra_fields
    self.assertIsInstance(free_new, torch._C._profiler._ExtraFields_Allocation)
    self.assertEqual(free_new.id, allocate_new.id)
    self.assertEqual(free_new.ptr, allocate_new.ptr)
def test_allocation_ids(self) -> None:
    # Runs the allocation ID checks with no extra work around the tracked
    # tensor.
    def do_nothing():
        return None

    self._test_allocation_ids(do_nothing, do_nothing)
def test_allocation_ids_with_other_ops(self) -> None:
    # Runs the allocation ID checks with unrelated ops executed before and
    # after the tracked tensor is used.
    x = torch.ones((1,))

    def before():
        return (x + 1).relu_()

    def after():
        return torch.zeros((1,)).cos()

    self._test_allocation_ids(before, after)
def test_impl_reuse(self) -> None:
    # Creates `repeats` short-lived tensors and checks that every
    # `aten::fill_` event reports a distinct `impl_ptr` for its first input.
    repeats = 1_000
    with profile(profile_memory=True, record_shapes=True) as p:
        for _ in range(repeats):
            torch.ones((1,))
        gc.collect()

    roots = p.profiler.kineto_results.experimental_event_tree()
    fill_events = [
        event for event in _utils.traverse_dfs(roots) if event.name == "aten::fill_"
    ]
    tensor_impls = [event.extra_fields.inputs[0].impl_ptr for event in fill_events]

    self.assertEqual(len(tensor_impls), repeats)
    self.assertEqual(len(set(tensor_impls)), repeats)
def test_allocation_id_uniqueness(self) -> None:
    # Creates `repeats` short-lived tensors and checks that the profiler
    # reports exactly `repeats` distinct non-None allocation IDs across op
    # inputs and allocation events.
    repeats = 1_000
    with profile(profile_memory=True, record_shapes=True) as p:
        for _ in range(repeats):
            torch.ones((1,))
        gc.collect()

    roots = p.profiler.kineto_results.experimental_event_tree()
    torch_op_type = torch._C._profiler._ExtraFields_TorchOp
    allocation_type = torch._C._profiler._ExtraFields_Allocation

    seen_ids = set()
    for event in _utils.traverse_dfs(roots):
        fields = event.extra_fields
        if isinstance(fields, torch_op_type):
            seen_ids.update(
                tensor.allocation_id
                for tensor in fields.inputs
                if isinstance(tensor, _TensorMetadata)
            )
        elif isinstance(fields, allocation_type):
            seen_ids.add(fields.allocation_id)

    # Entries without an allocation ID are not counted.
    seen_ids.discard(None)
    self.assertEqual(repeats, len(seen_ids))
def test_extra_fields(self):
    # Checks the extra-field types along the event chain for `torch.ones`:
    # Python C call -> aten::ones -> aten::empty -> [memory].
    with profile(with_stack=True, profile_memory=True) as p:
        _ = torch.ones((1,))

    event_tree = p.profiler.kineto_results.experimental_event_tree()
    ones_node = find_node_with_name(event_tree, "aten::ones")
    self.assertIsNotNone(ones_node)

    profiler_types = torch._C._profiler
    self.assertIsInstance(ones_node.extra_fields, profiler_types._ExtraFields_TorchOp)
    self.assertIsInstance(
        ones_node.parent.extra_fields, profiler_types._ExtraFields_PyCCall
    )

    empty_node = ones_node.children[0]
    self.assertEqual(empty_node.name, "aten::empty")

    memory_node = empty_node.children[0]
    self.assertEqual(memory_node.name, "[memory]")
    self.assertIsInstance(
        memory_node.extra_fields, profiler_types._ExtraFields_Allocation
    )
def test_tensor_properties(self):
    # Checks sizes, strides, layout, device, and dtype recorded for the
    # inputs of `aten::add`, plus its scope and sequence number.
    x = torch.ones(10, 10).as_strided([4, 4], [12, 3])
    y = torch.ones(4, 1, requires_grad=True)

    with profile(with_stack=True, profile_memory=True, record_shapes=True) as p:
        _ = x + y
        _ = x * y

    event_tree = p.profiler.kineto_results.experimental_event_tree()
    add_node = find_node_with_name(event_tree, "aten::add")
    self.assertIsNotNone(add_node)
    self.assertIsInstance(
        add_node.extra_fields, torch._C._profiler._ExtraFields_TorchOp
    )

    def per_input(attribute, fallback):
        # Inputs lacking `attribute` contribute `fallback`.
        return [
            getattr(entry, attribute, fallback)
            for entry in add_node.extra_fields.inputs
        ]

    cpu = torch.device("cpu")
    self.assertEqual(per_input("sizes", []), [[4, 4], [4, 1], []])
    self.assertEqual(per_input("strides", []), [[12, 3], [1, 1], []])
    self.assertEqual(
        per_input("layout", None), [torch.strided, torch.strided, None]
    )
    self.assertEqual(per_input("device", None), [cpu, cpu, None])
    self.assertEqual(
        per_input("dtype", None), [torch.float32, torch.float32, None]
    )

    self.assertEqual(
        add_node.extra_fields.scope, torch.profiler.RecordScope.FUNCTION
    )

    mul_node = find_node_with_name(event_tree, "aten::mul")
    self.assertIsNotNone(mul_node)
    # The multiply was issued immediately after the add.
    self.assertEqual(
        add_node.extra_fields.sequence_number + 1,
        mul_node.extra_fields.sequence_number,
    )
def test_sparse_tensors(self):
    # Checks the metadata recorded for sparse COO inputs of `aten::add`.
    indices = [[0, 1, 1], [2, 0, 2]]
    values = [3, 4, 5]
    s = torch.sparse_coo_tensor(indices, values, (2, 3))

    with profile(with_stack=True, profile_memory=True, record_shapes=True) as p:
        _ = s + s

    event_tree = p.profiler.kineto_results.experimental_event_tree()
    add_node = find_node_with_name(event_tree, "aten::add")
    self.assertIsNotNone(add_node)
    self.assertIsInstance(
        add_node.extra_fields, torch._C._profiler._ExtraFields_TorchOp
    )

    def per_input(attribute, fallback):
        # Inputs lacking `attribute` contribute `fallback`.
        return [
            getattr(entry, attribute, fallback)
            for entry in add_node.extra_fields.inputs
        ]

    cpu = torch.device("cpu")
    self.assertEqual(per_input("sizes", []), [[2, 3], [2, 3], []])
    self.assertEqual(per_input("strides", []), [[], [], []])
    self.assertEqual(
        per_input("layout", None), [torch.sparse_coo, torch.sparse_coo, None]
    )
    self.assertEqual(per_input("device", None), [cpu, cpu, None])
@unittest.skipIf(
    not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
)
def test_mkldnn_tensors(self):
    # Checks the metadata recorded for MKL-DNN inputs of `aten::add`.
    x = torch.ones(4, 3).to_mkldnn()

    with profile(with_stack=True, profile_memory=True, record_shapes=True) as p:
        _ = x + x

    event_tree = p.profiler.kineto_results.experimental_event_tree()
    add_node = find_node_with_name(event_tree, "aten::add")
    self.assertIsNotNone(add_node)
    self.assertIsInstance(
        add_node.extra_fields, torch._C._profiler._ExtraFields_TorchOp
    )

    def per_input(attribute, fallback):
        # Inputs lacking `attribute` contribute `fallback`.
        return [
            getattr(entry, attribute, fallback)
            for entry in add_node.extra_fields.inputs
        ]

    cpu = torch.device("cpu")
    self.assertEqual(per_input("sizes", []), [[4, 3], [4, 3], []])
    self.assertEqual(per_input("strides", []), [[], [], []])
    self.assertEqual(
        per_input("layout", None), [torch._mkldnn, torch._mkldnn, None]
    )
    self.assertEqual(per_input("device", None), [cpu, cpu, None])
def test_scalar_ins(self):
    # Checks how scalar arguments of `torch.add` appear in the recorded
    # inputs.
    x = torch.ones(5, 5)
    alpha = 0.9

    with profile(with_stack=True, profile_memory=True, record_shapes=True) as p:
        _ = torch.add(x, 9.1, alpha=alpha)

    event_tree = p.profiler.kineto_results.experimental_event_tree()
    add_node = find_node_with_name(event_tree, "aten::add")
    self.assertIsNotNone(add_node)

    def per_input(attribute, fallback):
        # Inputs lacking `attribute` contribute `fallback`.
        return [
            getattr(entry, attribute, fallback)
            for entry in add_node.extra_fields.inputs
        ]

    # The second argument to the add gets promoted to a zerodim Tensor
    self.assertEqual(
        per_input("dtype", None), [torch.float32, torch.float64, None]
    )
    self.assertEqual(per_input("sizes", []), [[5, 5], [], []])

    # `alpha` is recorded as the plain scalar value.
    recorded_alpha = add_node.extra_fields.inputs[2]
    self.assertEqual(recorded_alpha, alpha)
def test_tensor_lists(self):
    # Checks that a tensor-list argument (`torch.stack`) is recorded as a
    # list whose entries carry the storage pointers of the original tensors.
    x = torch.ones((1,))
    y = torch.ones((1,))

    with profile(with_stack=True, profile_memory=True, record_shapes=True) as p:
        _ = torch.stack((x, y))

    nodes = p.profiler.kineto_results.experimental_event_tree()
    node = find_node_with_name(nodes, "aten::stack")
    # Fail with a clear assertion (as the sibling tests do) rather than an
    # AttributeError on `None` if the op was not recorded.
    self.assertIsNotNone(node)

    inputs = node.extra_fields.inputs
    self.assertEqual(len(inputs), 2)

    # The first recorded input is the list of stacked tensors.
    self.assertIsInstance(inputs[0], list)
    self.assertEqual(len(inputs[0]), 2)
    self.assertEqual(x.storage().data_ptr(), inputs[0][0].storage_data_ptr)
    self.assertEqual(y.storage().data_ptr(), inputs[0][1].storage_data_ptr)
def test_nnmodule_params(self):
    # Checks that the module parameters recorded on Python call events match
    # the data and gradient storage pointers of `SimpleNet`'s layers.
    def flat_out_extrafields(nodes, out=None):
        # Depth-first collection of recorded module info that has parameters.
        collected = out if out is not None else []
        for node in nodes:
            fields = node.extra_fields
            module_info = (
                fields.module if isinstance(fields, _ExtraFields_PyCall) else None
            )
            if module_info and module_info.parameters:
                collected.append(module_info)
            flat_out_extrafields(node.children, collected)
        return collected

    inputs = torch.rand(10)
    net = SimpleNet()

    # Run one forward/backward pass first so every parameter has a gradient.
    out = net(inputs)
    torch.nn.functional.cross_entropy(out, torch.rand(2)).backward()

    with torch.profiler.profile(with_stack=True, profile_memory=True) as p:
        _ = net(inputs)

    event_tree = p.profiler.kineto_results.experimental_event_tree()
    modules = flat_out_extrafields(event_tree)
    self.assertEqual(
        len(modules), 2, f"Expected two parameter list, but got {len(modules)}"
    )

    params = [
        (param_name, param_meta.storage_data_ptr, grad_meta.storage_data_ptr)
        for module in modules
        for (param_name, param_meta, grad_meta) in module.parameters
    ]

    expected = []
    for layer in (net.fc1, net.fc2):
        expected += [
            (name, val.storage().data_ptr(), val.grad.storage().data_ptr())
            for name, val in layer._parameters.items()
        ]
    self.assertEqual(expected, params, f"{expected} vs. {params}")
def _flat_out_extrafields(self, nodes, out=None):
if out is None:
out = []
for node in nodes:
if (
isinstance(node.extra_fields, _ExtraFields_PyCall)
and node.extra_fields.optimizer
and node.extra_fields.optimizer.parameters
):
# avoiding OptInfo duplicates from iterations
addr = node.extra_fields.optimizer.parameters[0][0].storage_data_ptr
if not [o for o in out if addr == o.parameters[0][0].storage_data_ptr]:
out.append(node.extra_fields.optimizer)
self._flat_out_extrafields(node.children, out)
return out
def _check_results(self, opt, opts, check_items=False):
self.assertEqual(len(opts), 1, f"Expected 1 optimizer: len(opts): {len(opts)}")
self.assertEqual(
id(opt),
opts[0].self_ptr,
f"Optimizer addr ({id(opt)}) vs. profiled addr ({opts[0].self_ptr})",
)
if check_items:
self.assertEqual(len(opt.param_groups), len(opts))
for group, opt_ in zip(opt.param_groups, opts):
self.assertEqual(
[(v.storage().data_ptr()) for v in group.get("params", [])],
[(o.storage_data_ptr) for (o, _, _) in opt_.parameters],
)
for opt_ in opts:
observed_state = {
p.storage_data_ptr: {name: s.storage_data_ptr for name, s in state}
for (p, _, state) in opt_.parameters
}
# Make sure the profiler collected all optimizer state and check
# that the address recorded by the profiler is correct.
for parameter, parameter_state in opt.state.items():
self.assertEqual(
{
name: value.storage().data_ptr()
for name, value in parameter_state.items()
},
observed_state.get(parameter.storage().data_ptr(), []),
)
def test_optimizer(self):
    # Profiles one SGD training step and checks that exactly one optimizer
    # record, pointing at `opt`, was collected.
    inputs = torch.rand(10)
    with torch.profiler.profile(with_stack=True, profile_memory=True) as p:
        net = SimpleNet()
        opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

        opt.zero_grad()
        out = net(inputs)
        loss = torch.nn.functional.cross_entropy(out, torch.rand(2))
        loss.backward()
        opt.step()

    event_tree = p.profiler.kineto_results.experimental_event_tree()
    recorded = self._flat_out_extrafields(event_tree)
    self._check_results(opt, recorded, False)
def _test_optimizer_parameters(self, optimizer_factory):
    """Profile two training steps and check the recorded optimizer contents.

    `optimizer_factory` is called with the network's parameters and must
    return the optimizer to exercise.
    """
    inputs = torch.rand(10)
    with torch.profiler.profile(with_stack=True, profile_memory=True) as p:
        net = SimpleNet()
        opt = optimizer_factory(net.parameters())
        for _ in range(2):
            opt.zero_grad()
            out = net(inputs)
            loss = torch.nn.functional.cross_entropy(out, torch.rand(2))
            loss.backward()
            opt.step()

    event_tree = p.profiler.kineto_results.experimental_event_tree()
    recorded = self._flat_out_extrafields(event_tree)
    self._check_results(opt, recorded, True)
def test_optimizer_parameters_sgd(self):
    # Runs the optimizer parameter checks with SGD plus momentum.
    def make_sgd(params):
        return torch.optim.SGD(params, lr=0.01, momentum=0.9)

    self._test_optimizer_parameters(make_sgd)
def test_optimizer_parameters_adam(self):
    # Runs the optimizer parameter checks with Adam constructed with
    # `foreach=True`.
    def make_adam(params):
        return torch.optim.Adam(params, foreach=True)

    self._test_optimizer_parameters(make_adam)
def test_allocations(self):
    # Checks the fields of the `[memory]` event for allocating a 3x4 tensor
    # and of the matching event when that tensor is freed.
    gc.collect()
    with profile(profile_memory=True) as p:
        x = torch.empty((3, 4))

    alloc_tree = p.profiler.kineto_results.experimental_event_tree()
    alloc_node = find_node_with_name(alloc_tree, "[memory]")
    self.assertIsNotNone(alloc_node)

    alloc_size = 3 * 4 * 4  # fp32 -> 4 bytes
    cpu = torch.device("cpu")
    ptr = alloc_node.extra_fields.ptr
    self.assertGreater(ptr, 0)
    self.assertEqual(alloc_node.extra_fields.alloc_size, alloc_size)
    self.assertEqual(alloc_node.extra_fields.device, cpu)
    total_allocated = alloc_node.extra_fields.total_allocated

    # total_reserved is only for CUDACachingAllocator
    self.assertEqual(alloc_node.extra_fields.total_reserved, 0)

    with profile(profile_memory=True) as p:
        del x
        gc.collect()

    free_tree = p.profiler.kineto_results.experimental_event_tree()
    free_node = find_node_with_name(free_tree, "[memory]")
    self.assertIsNotNone(free_node)

    # The free event mirrors the allocation: same pointer, negated size.
    self.assertEqual(free_node.extra_fields.ptr, ptr)
    self.assertEqual(free_node.extra_fields.alloc_size, -alloc_size)
    self.assertEqual(free_node.extra_fields.device, cpu)
    self.assertEqual(
        free_node.extra_fields.total_allocated, total_allocated - alloc_size
    )
def test_refcounts(self):
    # Checks that objects captured by closures are released once the last
    # reference is gone, even though a closure ran under
    # `profile(with_stack=True)`.
    # NOTE(review): presumably this guards against the profiler keeping
    # references to the frames it recorded -- confirm.
    class Sentinel:
        pass

    def make():
        outer_sentinel = Sentinel()

        def outer():
            # Python will only close over variables used in the function.
            _ = outer_sentinel
            inner_sentinel = Sentinel()

            def inner():
                _ = inner_sentinel

            with profile(with_stack=True):
                inner()

            # Only a weak reference escapes; `inner_sentinel` itself is local.
            return weakref.ref(inner_sentinel)

        return outer, weakref.ref(outer_sentinel)

    # Use a factory function to ensure the test scope never sees strong
    # references. `del` has strange semantics that interact with closures
    # at an AST level, so this is simpler.
    outer, outer_sentinel_ref = make()
    inner_sentinel_ref = outer()

    # After `outer()` returns, the inner sentinel must already be gone.
    self.assertIsNone(inner_sentinel_ref())

    # `outer` holds the last reference via closure.
    self.assertIsNotNone(outer_sentinel_ref())

    del outer
    self.assertIsNone(outer_sentinel_ref())
# When executed as a script, run this file's tests through `run_tests` from
# `torch.testing._internal.common_utils`.
if __name__ == "__main__":
    run_tests()