mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Signed-off-by: Edward Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/160573 Approved by: https://github.com/bdhirsh, https://github.com/albanD
232 lines
8.1 KiB
Python
232 lines
8.1 KiB
Python
# Owner(s): ["oncall: profiler"]
|
|
# ruff: noqa: F841
|
|
|
|
from typing import Any
|
|
|
|
import torch
|
|
import torch.optim
|
|
import torch.utils.data
|
|
import torch.utils.data.datapipes as dp
|
|
from torch._dispatch.python import enable_python_dispatcher
|
|
from torch.autograd import (
|
|
_record_function_with_args_enter,
|
|
_record_function_with_args_exit,
|
|
)
|
|
from torch.autograd.profiler import profile as _profile
|
|
from torch.profiler import kineto_available, record_function
|
|
from torch.testing._internal.common_utils import run_tests, TestCase
|
|
|
|
|
|
# If tqdm is not shut down properly, it leaves its monitor thread alive.
# That breaks the multithreading test, which checks every profiler event
# against its thread id: events coming from these lingering threads all carry
# a TID of (uint64_t)(-1), which is invalid.
# The workaround is to turn off tqdm's monitor thread whenever tqdm can be
# imported. Since these are unit tests, it is safe to disable it.
# NOTE(review): no multithreading test is visible in this module; confirm
# whether this workaround is still needed here or was copied from elsewhere.
try:
    import tqdm
except ImportError:
    # tqdm is optional; without it there is no monitor thread to disable.
    pass
else:
    tqdm.tqdm.monitor_interval = 0


# Alias for a JSON-like object (string keys, arbitrary values).
# NOTE(review): not referenced by the visible tests in this module.
Json = dict[str, Any]
|
|
|
|
|
|
class TestRecordFunction(TestCase):
|
|
def _record_function_with_param(self):
|
|
u = torch.randn(3, 4, 5, requires_grad=True)
|
|
with _profile(
|
|
with_stack=True, use_kineto=kineto_available(), record_shapes=True
|
|
) as prof:
|
|
with record_function("## TEST 1 ##", "1, 2, 3"):
|
|
rf_handle = _record_function_with_args_enter(
|
|
"## TEST 2 ##", 1, False, 2.5, [u, u], "hello", u
|
|
)
|
|
_record_function_with_args_exit(rf_handle)
|
|
with record_function("## TEST 3 ##"):
|
|
rf_handle = _record_function_with_args_enter("## TEST 4 ##")
|
|
_record_function_with_args_exit(rf_handle)
|
|
return prof
|
|
|
|
def test_record_function(self):
|
|
prof_result = self._record_function_with_param()
|
|
found_test_1 = False
|
|
found_test_2 = False
|
|
found_test_3 = False
|
|
found_test_4 = False
|
|
for e in prof_result.function_events:
|
|
if "## TEST 1 ##" == e.name:
|
|
found_test_1 = True
|
|
self.assertTrue(e.input_shapes == [[]])
|
|
elif "## TEST 2 ##" == e.name:
|
|
found_test_2 = True
|
|
self.assertTrue(e.input_shapes == [[], [], [], [], [], [3, 4, 5]])
|
|
elif "## TEST 3 ##" == e.name:
|
|
found_test_3 = True
|
|
self.assertTrue(e.input_shapes == [])
|
|
elif "## TEST 4 ##" == e.name:
|
|
found_test_4 = True
|
|
self.assertTrue(e.input_shapes == [])
|
|
self.assertTrue(found_test_1)
|
|
self.assertTrue(found_test_2)
|
|
self.assertTrue(found_test_3)
|
|
self.assertTrue(found_test_4)
|
|
|
|
def test_datapipe_with_record_function(self):
|
|
with _profile(
|
|
with_stack=True, use_kineto=kineto_available(), record_shapes=True
|
|
) as prof:
|
|
input_dp1 = dp.iter.IterableWrapper(range(4))
|
|
input_dp2 = dp.iter.IterableWrapper(range(4, 8))
|
|
input_dp3 = dp.iter.IterableWrapper(range(8, 12))
|
|
output_dp = input_dp1.mux(input_dp2, input_dp3)
|
|
output = list(output_dp)
|
|
|
|
has_iter = False
|
|
has_mux = False
|
|
for e in prof.function_events:
|
|
if has_iter and has_mux:
|
|
break
|
|
|
|
if not has_iter and "IterableWrapper" in e.name:
|
|
has_iter = True
|
|
if not has_mux and "Multiplexer" in e.name:
|
|
has_mux = True
|
|
self.assertTrue(has_iter)
|
|
self.assertTrue(has_mux)
|
|
|
|
def test_datapipe_delegation_with_profiler(self):
|
|
class IDPIterator(torch.utils.data.IterDataPipe):
|
|
def __init__(self) -> None:
|
|
self.data = list(range(10))
|
|
self._idx = 0
|
|
|
|
def __iter__(self):
|
|
return self
|
|
|
|
def __next__(self):
|
|
if self._idx >= 10:
|
|
self._idx = 0
|
|
raise StopIteration
|
|
self._idx += 1
|
|
return self.data[self._idx - 1]
|
|
|
|
def get_value(self, idx):
|
|
return self.data[idx]
|
|
|
|
dp1 = IDPIterator() # The object itself is an iterator
|
|
self.assertEqual(5, dp1.get_value(5))
|
|
it_dp1 = iter(dp1) # This creates the 1st iterator
|
|
self.assertEqual(5, it_dp1.get_value(5)) # type: ignore[attr-defined]
|
|
self.assertEqual(list(range(10)), list(it_dp1))
|
|
|
|
class IDPDelegator(torch.utils.data.IterDataPipe):
|
|
def __init__(self, datapipe):
|
|
self.datapipe = datapipe
|
|
|
|
def __iter__(self):
|
|
return iter(self.datapipe)
|
|
|
|
dp2 = IDPDelegator(dp1)
|
|
it_dp2 = iter(dp2)
|
|
self.assertEqual(5, it_dp2.get_value(5))
|
|
self.assertEqual(list(range(10)), list(it_dp2))
|
|
|
|
def test_datapipe_with_record_function_fork(self):
|
|
with _profile(
|
|
with_stack=True, use_kineto=kineto_available(), record_shapes=True
|
|
) as prof:
|
|
input_dp = dp.iter.IterableWrapper(range(10))
|
|
dp1, dp2, dp3 = input_dp.fork(num_instances=3)
|
|
output1 = list(dp1)
|
|
has_iter = False
|
|
has_child = False
|
|
for e in prof.function_events:
|
|
if has_iter and has_child:
|
|
break
|
|
|
|
if not has_iter and "IterableWrapper" in e.name:
|
|
has_iter = True
|
|
if not has_child and "_ChildDataPipe" in e.name:
|
|
has_child = True
|
|
self.assertTrue(has_iter)
|
|
self.assertTrue(has_child)
|
|
|
|
def test_python_dispatch_mode_record_function(self):
|
|
from torch.utils._python_dispatch import TorchDispatchMode
|
|
|
|
class TestDispatchMode(TorchDispatchMode):
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
return func(*args, **kwargs)
|
|
|
|
with _profile() as prof:
|
|
with enable_python_dispatcher():
|
|
with TestDispatchMode():
|
|
x = torch.randn(3, 4)
|
|
y = torch.sin(x)
|
|
|
|
found_python_dispatch_mode = False
|
|
for e in prof.function_events:
|
|
if e.name == "PythonDispatchMode":
|
|
found_python_dispatch_mode = True
|
|
break
|
|
self.assertTrue(
|
|
found_python_dispatch_mode,
|
|
"PythonDispatchMode record function not found in profiler events",
|
|
)
|
|
|
|
def test_python_subclass_record_function(self):
|
|
class TestTensorSubclass(torch.Tensor):
|
|
@staticmethod
|
|
def __new__(cls, elem):
|
|
r = torch.Tensor._make_wrapper_subclass(
|
|
cls,
|
|
elem.size(),
|
|
dtype=elem.dtype,
|
|
device=elem.device,
|
|
requires_grad=elem.requires_grad,
|
|
)
|
|
r.elem = elem
|
|
return r
|
|
|
|
@classmethod
|
|
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
|
|
def unwrap(x):
|
|
return x.elem if isinstance(x, TestTensorSubclass) else x
|
|
|
|
def wrap(x):
|
|
return TestTensorSubclass(x) if isinstance(x, torch.Tensor) else x
|
|
|
|
unwrapped_args = tuple(unwrap(arg) for arg in args)
|
|
unwrapped_kwargs = {k: unwrap(v) for k, v in kwargs.items()}
|
|
result = func(*unwrapped_args, **unwrapped_kwargs)
|
|
|
|
if isinstance(result, torch.Tensor):
|
|
return TestTensorSubclass(result)
|
|
return result
|
|
|
|
with _profile() as prof:
|
|
with enable_python_dispatcher():
|
|
x = TestTensorSubclass(torch.randn(3, 4))
|
|
y = torch.sin(x)
|
|
|
|
found_python_subclass = False
|
|
for e in prof.function_events:
|
|
if e.name == "PythonSubclass":
|
|
found_python_subclass = True
|
|
break
|
|
self.assertTrue(
|
|
found_python_subclass,
|
|
"PythonSubclass record function not found in profiler events",
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
    # Only when executed as a script: hand control to PyTorch's ``run_tests``
    # entry point instead of invoking unittest directly.
    run_tests()
|