# File: pytorch/test/profiler/test_record_function.py
# Repository-viewer listing: 232 lines, 8.1 KiB, Python

# Owner(s): ["oncall: profiler"]
# ruff: noqa: F841
from typing import Any
import torch
import torch.optim
import torch.utils.data
import torch.utils.data.datapipes as dp
from torch._dispatch.python import enable_python_dispatcher
from torch.autograd import (
_record_function_with_args_enter,
_record_function_with_args_exit,
)
from torch.autograd.profiler import profile as _profile
from torch.profiler import kineto_available, record_function
from torch.testing._internal.common_utils import run_tests, TestCase
# If tqdm is not shut down properly, it leaves its monitor thread alive.
# This causes an issue in the multithreading test because we check all events
# in that test with their TIDs. The events that correspond to these lingering
# threads all have a TID of (uint64_t)(-1), which is invalid.
# The workaround is to turn off the monitor thread when tqdm is loaded.
# Since these are unit tests, it is safe to turn off the monitor thread.
try:
    import tqdm

    tqdm.tqdm.monitor_interval = 0
except ImportError:
    # tqdm is optional; without it there is no monitor thread to disable.
    pass

# Alias for a dict with string keys and arbitrary values.
Json = dict[str, Any]
class TestRecordFunction(TestCase):
    """Tests that labelled scopes and dispatch hooks show up as profiler events."""

    def _record_function_with_param(self):
        """Profile four labelled scopes and return the profiler object.

        ``## TEST 1 ##`` and ``## TEST 3 ##`` are opened with the
        ``record_function`` context manager; ``## TEST 2 ##`` and
        ``## TEST 4 ##`` are opened and closed inside them with the raw
        ``_record_function_with_args_enter`` / ``_record_function_with_args_exit``
        pair.
        """
        u = torch.randn(3, 4, 5, requires_grad=True)
        with _profile(
            with_stack=True, use_kineto=kineto_available(), record_shapes=True
        ) as prof:
            with record_function("## TEST 1 ##", "1, 2, 3"):
                rf_handle = _record_function_with_args_enter(
                    "## TEST 2 ##", 1, False, 2.5, [u, u], "hello", u
                )
                _record_function_with_args_exit(rf_handle)
            with record_function("## TEST 3 ##"):
                rf_handle = _record_function_with_args_enter("## TEST 4 ##")
                _record_function_with_args_exit(rf_handle)
        return prof

    def test_record_function(self):
        """Every labelled scope is recorded with the expected input shapes."""
        prof_result = self._record_function_with_param()

        # Scope label -> input shapes the profiler is expected to report.
        expected_shapes = {
            "## TEST 1 ##": [[]],
            "## TEST 2 ##": [[], [], [], [], [], [3, 4, 5]],
            "## TEST 3 ##": [],
            "## TEST 4 ##": [],
        }

        seen = set()
        for e in prof_result.function_events:
            if e.name not in expected_shapes:
                continue
            seen.add(e.name)
            self.assertTrue(
                e.input_shapes == expected_shapes[e.name],
                f"unexpected input shapes for {e.name}: {e.input_shapes}",
            )

        for label in expected_shapes:
            self.assertTrue(label in seen, f"{label} not found in profiler events")

    def test_datapipe_with_record_function(self):
        """Iterating a ``mux`` pipeline records IterableWrapper and Multiplexer events."""
        with _profile(
            with_stack=True, use_kineto=kineto_available(), record_shapes=True
        ) as prof:
            input_dp1 = dp.iter.IterableWrapper(range(4))
            input_dp2 = dp.iter.IterableWrapper(range(4, 8))
            input_dp3 = dp.iter.IterableWrapper(range(8, 12))
            output_dp = input_dp1.mux(input_dp2, input_dp3)
            # The result is unused; the pipeline is drained under the profiler.
            output = list(output_dp)

        event_names = [e.name for e in prof.function_events]
        self.assertTrue(
            any("IterableWrapper" in name for name in event_names),
            "IterableWrapper event not found in profiler events",
        )
        self.assertTrue(
            any("Multiplexer" in name for name in event_names),
            "Multiplexer event not found in profiler events",
        )

    def test_datapipe_delegation_with_profiler(self):
        """``iter()`` on a self-iterating pipe, or on a delegator, keeps its methods."""

        class IDPIterator(torch.utils.data.IterDataPipe):
            def __init__(self) -> None:
                self.data = list(range(10))
                self._idx = 0

            def __iter__(self):
                return self

            def __next__(self):
                if self._idx >= 10:
                    # Rewind before stopping so the pipe can be iterated again.
                    self._idx = 0
                    raise StopIteration
                self._idx += 1
                return self.data[self._idx - 1]

            def get_value(self, idx):
                return self.data[idx]

        dp1 = IDPIterator()  # The object itself is an iterator
        self.assertEqual(5, dp1.get_value(5))
        it_dp1 = iter(dp1)  # This creates the 1st iterator
        self.assertEqual(5, it_dp1.get_value(5))  # type: ignore[attr-defined]
        self.assertEqual(list(range(10)), list(it_dp1))

        class IDPDelegator(torch.utils.data.IterDataPipe):
            def __init__(self, datapipe):
                self.datapipe = datapipe

            def __iter__(self):
                return iter(self.datapipe)

        dp2 = IDPDelegator(dp1)
        it_dp2 = iter(dp2)
        self.assertEqual(5, it_dp2.get_value(5))
        self.assertEqual(list(range(10)), list(it_dp2))

    def test_datapipe_with_record_function_fork(self):
        """Iterating a forked pipe records IterableWrapper and _ChildDataPipe events."""
        with _profile(
            with_stack=True, use_kineto=kineto_available(), record_shapes=True
        ) as prof:
            input_dp = dp.iter.IterableWrapper(range(10))
            dp1, dp2, dp3 = input_dp.fork(num_instances=3)
            # Only the first child is drained; the result itself is unused.
            output1 = list(dp1)

        event_names = [e.name for e in prof.function_events]
        self.assertTrue(
            any("IterableWrapper" in name for name in event_names),
            "IterableWrapper event not found in profiler events",
        )
        self.assertTrue(
            any("_ChildDataPipe" in name for name in event_names),
            "_ChildDataPipe event not found in profiler events",
        )

    def test_python_dispatch_mode_record_function(self):
        """Ops run under a ``TorchDispatchMode`` record a ``PythonDispatchMode`` event."""
        from torch.utils._python_dispatch import TorchDispatchMode

        class TestDispatchMode(TorchDispatchMode):
            # Pass-through mode: forwards every op unchanged.
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                return func(*args, **kwargs)

        with _profile() as prof:
            with enable_python_dispatcher():
                with TestDispatchMode():
                    x = torch.randn(3, 4)
                    y = torch.sin(x)

        self.assertTrue(
            any(e.name == "PythonDispatchMode" for e in prof.function_events),
            "PythonDispatchMode record function not found in profiler events",
        )

    def test_python_subclass_record_function(self):
        """Ops on a wrapper tensor subclass record a ``PythonSubclass`` event."""

        class TestTensorSubclass(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                r = torch.Tensor._make_wrapper_subclass(
                    cls,
                    elem.size(),
                    dtype=elem.dtype,
                    device=elem.device,
                    requires_grad=elem.requires_grad,
                )
                r.elem = elem
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}

                def unwrap(x):
                    return x.elem if isinstance(x, TestTensorSubclass) else x

                # Run the op on the wrapped tensors, then re-wrap a tensor result.
                unwrapped_args = tuple(unwrap(arg) for arg in args)
                unwrapped_kwargs = {k: unwrap(v) for k, v in kwargs.items()}
                result = func(*unwrapped_args, **unwrapped_kwargs)

                if isinstance(result, torch.Tensor):
                    return TestTensorSubclass(result)
                return result

        with _profile() as prof:
            with enable_python_dispatcher():
                x = TestTensorSubclass(torch.randn(3, 4))
                y = torch.sin(x)

        self.assertTrue(
            any(e.name == "PythonSubclass" for e in prof.function_events),
            "PythonSubclass record function not found in profiler events",
        )
# Run the suite through PyTorch's test runner when executed as a script.
if __name__ == "__main__":
    run_tests()