[memory debugging] Extract frame information from inductor (#95753)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95753
Approved by: https://github.com/Chillee
Committed by: PyTorch MergeBot
Parent: e74f70d212
Commit: 3162f71787
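
In short, this change threads debug attribution from inductor-generated code back to the original model source: the scheduler tags generated wrapper code with LineContext markers, IndentedBuffer.getvaluewithlinemap() turns those markers into a map from generated line numbers to fx nodes, PyCodeCache keeps that map per generated file and exposes stack_frames_for_code(), and the C++ traceback symbolizer consults it so that recorded memory-history frames include the original Python frames rather than only a line in the generated wrapper.

A hedged sketch of how the result can be consumed (the structure returned by torch.cuda.memory._snapshot() is private and version-dependent, so the field names below are assumptions and the walker is deliberately defensive):

import torch

def dump_alloc_frames(snapshot):
    # Field names ("segments", "blocks", "history", "frames") are assumptions
    # about a private structure; .get() keeps this from crashing if they change.
    segments = snapshot.get("segments", []) if isinstance(snapshot, dict) else snapshot
    for seg in segments:
        for block in seg.get("blocks", []):
            for entry in block.get("history", [block]):
                for frame in entry.get("frames", []):
                    print(f'{frame["filename"]}:{frame["line"]} {frame["name"]}')

torch.cuda.memory._record_memory_history(True)
torch.compile(lambda x: torch.sigmoid(x @ x))(torch.rand(8, 8, device="cuda"))
torch.cuda.memory._record_memory_history(False)
dump_alloc_frames(torch.cuda.memory._snapshot())

With this patch, frames recorded for allocations made inside the compiled region point at the original model code (for example "called_inside_compile" in the test below) instead of only the generated file.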
@@ -7065,6 +7065,28 @@ if HAS_CUDA and not TEST_WITH_ASAN:

             self.assertTrue(torch.allclose(module(input), traced(input)))

+        def test_memory_history_inductor(self):
+            def called_inside_compile(x, w, b):
+                a = x @ w + b
+                return torch.sigmoid(a)
+
+            @torch.compile
+            def fn(x, w, b):
+                x = called_inside_compile(x, w, b)
+                return called_inside_compile(x, w, b)
+
+            w = torch.rand(3, 3, device="cuda")
+            b = torch.rand(3, device="cuda")
+            x = torch.rand(3, device="cuda")
+            try:
+                torch.cuda.memory.empty_cache()
+                torch.cuda.memory._record_memory_history(True)
+                r = fn(x, w, b)
+            finally:
+                torch.cuda.memory._record_memory_history(False)
+            snapshot = str(torch.cuda.memory._snapshot())
+            self.assertTrue("called_inside_compile" in snapshot)
+
     copy_tests(CommonTemplate, CudaTests, "cuda")

 class CudaReproTests(TestCase):
@@ -15,6 +15,7 @@ import sys
 import sysconfig
 import tempfile
 import types
+from bisect import bisect_right
 from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
 from ctypes import cdll
 from functools import partial
@@ -635,10 +636,11 @@ class CppCodeCache:

 class PyCodeCache:
     cache = dict()
+    linemaps = dict()
     clear = staticmethod(cache.clear)

     @classmethod
-    def load(cls, source_code, extra=""):
+    def load(cls, source_code, extra="", linemap=()):
         key, path = write(source_code, "py", extra)
         if key not in cls.cache:
             with open(path) as f:
@@ -654,8 +656,36 @@ class PyCodeCache:
                 exec(code, mod.__dict__, mod.__dict__)
                 # another thread might set this first
                 cls.cache.setdefault(key, mod)
+                cls.linemaps[path] = linemap

         return cls.cache[key]

+    @classmethod
+    @functools.lru_cache(None)
+    def stack_frames_for_code(cls, path, lineno):
+        if path not in cls.linemaps:
+            return None
+        # [(starting_line, <fx node>), ...]
+        linemap = cls.linemaps[path]
+        p = bisect_right(linemap, lineno, key=lambda x: x[0])
+        if p == 0:
+            return None
+        _, entry = linemap[p - 1]
+        if not entry:
+            return None
+
+        def parse_stack_trace(stack_trace):
+            # ideally fx stores stack traces as data rather than a string
+            # but this is not along a performance critical path
+            regex = r'File "(.+)", line (\d+), in (.+)\n'
+            matches = re.findall(regex, stack_trace)
+            return [
+                {"filename": f, "line": int(l), "name": n}
+                for f, l, n in reversed(matches)
+            ]
+
+        return parse_stack_trace(entry.stack_trace)
+

 class TritonCodeCache:
     @staticmethod
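
The lookup above (presumably in torch/_inductor/codecache.py) resolves a line in a generated file to the fx node whose code region contains it. A self-contained, hedged illustration of the same bisect pattern, with plain strings standing in for fx nodes (note that the key= argument of bisect_right requires Python 3.10+):

from bisect import bisect_right

# (first_generated_line, origin) pairs, sorted by line number
linemap = [(10, "buf0 = kernel0(...)"), (25, "buf1 = kernel1(...)"), (40, "return buf1")]

def origin_for(lineno):
    # find the last entry whose starting line is <= lineno
    p = bisect_right(linemap, lineno, key=lambda x: x[0])
    return None if p == 0 else linemap[p - 1][1]

assert origin_for(12) == "buf0 = kernel0(...)"
assert origin_for(31) == "buf1 = kernel1(...)"
assert origin_for(5) is None  # before any mapped region

In the real map each entry holds an fx node, and its stack_trace string is parsed into {"filename", "line", "name"} dicts by parse_stack_trace above.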
@@ -17,6 +17,7 @@ from ..utils import (
     cache_on_self,
     get_benchmark_name,
     has_triton,
+    LineContext,
     sympy_dot,
     sympy_product,
     sympy_symbol,
@@ -562,7 +563,7 @@ class WrapperCodeGen(CodeGen):

         self.add_benchmark_harness(result)

-        return result.getvalue()
+        return result.getvaluewithlinemap()

     def codegen_inputs(self, code: IndentedBuffer, graph_inputs: Dict[str, ir.Buffer]):
         """Assign all symbolic shapes to locals"""
@@ -736,6 +737,9 @@ class WrapperCodeGen(CodeGen):
     def writeline(self, line):
         self.lines.append(line)

+    def enter_context(self, ctx):
+        self.lines.append(LineContext(ctx))
+

 class CppWrapperCodeGen(WrapperCodeGen):
     """
@@ -619,11 +619,11 @@ class GraphLowering(torch.fx.Interpreter):
     def compile_to_module(self):
         from .codecache import PyCodeCache

-        code = self.codegen()
+        code, linemap = self.codegen()
         if config.debug:
             print(code)

-        mod = PyCodeCache.load(code)
+        mod = PyCodeCache.load(code, linemap=linemap)
         for name, value in self.constants.items():
             setattr(mod, name, value)

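
Together with the PyCodeCache changes, compile_to_module() now forwards the (line, fx node) map alongside the generated source. A hedged sketch of the resulting surface, using types.SimpleNamespace as a stand-in for an fx node (only a .stack_trace string is needed) and assuming the loaded module exposes its on-disk path as __file__, which is an internal detail:

from types import SimpleNamespace

from torch._inductor.codecache import PyCodeCache

source = "def call(args):\n    return args\n"
fake_node = SimpleNamespace(
    stack_trace='File "model.py", line 42, in forward\n    return self.net(x)\n'
)
mod = PyCodeCache.load(source, linemap=[(1, fake_node)])
path = mod.__file__  # assumption: load() records the cache-file path here
print(PyCodeCache.stack_frames_for_code(path, 2))
# expected: [{'filename': 'model.py', 'line': 42, 'name': 'forward'}]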
@@ -674,6 +674,10 @@ class Scheduler:
         self.buffer_names_to_free = set()
         self.buffer_names_no_longer_needed = set()

+        # fx graph node to the position it appears in the graph
+        # for debug attribution
+        self.origin_to_index = {}
+
     def debug_draw_graph(self):
         """Generate an image of the graph for debugging"""
         if os.environ.get("INDUCTOR_WRITE_SCHEDULER_GRAPH", None) == "1":
@@ -1164,9 +1168,21 @@ class Scheduler:
             self.backends[device] = self.create_backend(device)
         return self.backends[device]

+    def enter_context(self, node):
+        def get_order(n):
+            if n not in self.origin_to_index:
+                self.origin_to_index.update({n: i for i, n in enumerate(n.graph.nodes)})
+            return self.origin_to_index[n]
+
+        origins = [(get_order(e), e) for n in node.get_nodes() for e in n.node.origins]
+        if origins:
+            _, last = max(origins)
+            V.graph.wrapper_code.enter_context(last)
+
     @dynamo_timed
     def codegen(self):
         for node in self.nodes:
+            self.enter_context(node)
             self.buffer_names_no_longer_needed.update(node.last_usage)

             if not isinstance(node, NopKernelSchedulerNode):
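
Each scheduler node can originate from several fx nodes (node.node.origins); enter_context picks the origin that appears last in the fx graph and emits it as a LineContext marker into the wrapper code. A hedged, self-contained model of that ordering logic, with plain strings standing in for fx.Node objects:

origin_to_index = {}

def get_order(node, graph_nodes):
    if node not in origin_to_index:
        # index every node of this graph the first time any of them is queried
        origin_to_index.update({n: i for i, n in enumerate(graph_nodes)})
    return origin_to_index[node]

graph_nodes = ["x", "add", "mul", "sigmoid"]  # stand-ins for fx.Node objects
origins = ["add", "mul"]                      # origins of one fused kernel
_, last = max((get_order(o, graph_nodes), o) for o in origins)
assert last == "mul"                          # "mul" appears later in the graph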
@@ -12,7 +12,7 @@ import tempfile
 import textwrap
 import time
 from io import StringIO
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, NamedTuple, Optional, Union
 from unittest import mock

 import sympy
@@ -420,6 +420,10 @@ def get_dtype_size(dtype):
     return torch.empty((), dtype=dtype).element_size()


+class LineContext(NamedTuple):
+    context: Any
+
+
 class IndentedBuffer:
     tabwidth = 4

@@ -427,19 +431,27 @@
         self._lines = []
         self._indent = initial_indent

-    def getvalue(
-        self,
-    ):
+    def getvaluewithlinemap(self):
         buf = StringIO()
+        p = 1
+        linemap = []
         for line in self._lines:
             if isinstance(line, DeferredLineBase):
                 line = line()
                 if line is None:
                     continue
+            elif isinstance(line, LineContext):
+                linemap.append((p, line.context))
+                continue
             assert isinstance(line, str)
             buf.write(line)
             buf.write("\n")
-        return buf.getvalue()
+            p += 1 + line.count("\n")
+        return buf.getvalue(), linemap
+
+    def getvalue(self):
+        v, _ = self.getvaluewithlinemap()
+        return v

     def getrawvalue(self):
         buf = StringIO()
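
A hedged, stripped-down model of the behavior added above: LineContext entries never reach the emitted text; getvaluewithlinemap() instead records the line number of the next real line together with the marker's context (plain strings stand in for fx nodes here):

from typing import Any, NamedTuple

class LineContext(NamedTuple):
    context: Any

class TinyBuffer:
    def __init__(self):
        self._lines = []

    def writeline(self, line):
        self._lines.append(line)

    def getvaluewithlinemap(self):
        out, linemap, p = [], [], 1
        for line in self._lines:
            if isinstance(line, LineContext):
                # record "the next emitted line belongs to this context"
                linemap.append((p, line.context))
                continue
            out.append(line)
            p += 1 + line.count("\n")
        return "\n".join(out) + "\n", linemap

buf = TinyBuffer()
buf.writeline(LineContext("origin: add"))
buf.writeline("buf0 = empty_strided((8, 8))")
buf.writeline("kernel0(buf0)")
buf.writeline(LineContext("origin: sigmoid"))
buf.writeline("kernel1(buf0)")
code, linemap = buf.getvaluewithlinemap()
assert linemap == [(1, "origin: add"), (3, "origin: sigmoid")]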
@@ -448,6 +460,8 @@
                 line = line()
                 if line is None:
                     continue
+            elif isinstance(line, LineContext):
+                continue
             assert isinstance(line, str)
             # backslash implies line continuation
             if line.endswith("\\"):
@@ -467,7 +481,9 @@
         return " " * (self._indent * self.tabwidth)

     def writeline(self, line):
-        if isinstance(line, DeferredLineBase):
+        if isinstance(line, LineContext):
+            self._lines.append(line)
+        elif isinstance(line, DeferredLineBase):
             self._lines.append(line.with_prefix(self.prefix()))
         elif line.strip():
             self._lines.append(f"{self.prefix()}{line}")
@@ -493,12 +509,15 @@
         if isinstance(other_code, IndentedBuffer):
             dedent = float("inf")
             for line in other_code._lines:
-                if line:
+                if not isinstance(line, LineContext) and line:
                     dedent = min(dedent, len(line) - len(line.lstrip()))
             if math.isinf(dedent):
                 dedent = 0
             for line in other_code._lines:
-                IndentedBuffer.writeline(self, line[dedent:])
+                if isinstance(line, LineContext):
+                    self._lines.append(line)
+                else:
+                    IndentedBuffer.writeline(self, line[dedent:])
         else:
             other_code = textwrap.dedent(other_code)
             if strip:
@@ -56,6 +56,16 @@ struct PythonTraceback : public CapturedTraceback::Python {
     py::str name_s = "name";
     py::str filename_s = "filename";

+    auto torch = py::module::import("torch");
+    py::object stack_frames_for_code;
+    if (py::hasattr(torch, "_inductor")) {
+      py::object inductor = torch.attr("_inductor");
+      if (py::hasattr(inductor, "codecache")) {
+        stack_frames_for_code = inductor.attr("codecache")
+                                    .attr("PyCodeCache")
+                                    .attr("stack_frames_for_code");
+      }
+    }
     for (const auto& f : to_symbolize) {
       auto f_code = (PyCodeObject*)f.code;
       py::handle filename = f_code->co_filename;
@@ -67,6 +77,20 @@ struct PythonTraceback : public CapturedTraceback::Python {
           py::cast<std::string>(filename),
           py::cast<std::string>(funcname),
           (uint64_t)lineno});
+      // find all the additional frames associated with inductor generated
+      // code
+      if (stack_frames_for_code.ptr()) {
+        py::object extra = stack_frames_for_code(filename, lineno);
+        if (!extra.is_none()) {
+          for (py::handle h : extra) {
+            result.tracebacks.back().push_back(result.all_frames.size());
+            result.all_frames.emplace_back(unwind::Frame{
+                py::cast<std::string>(h[filename_s]),
+                py::cast<std::string>(h[name_s]),
+                py::cast<uint64_t>(h[line_s])});
+          }
+        }
+      }
     }
   }
 };
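
The C++ symbolizer above relies on a small contract: stack_frames_for_code(filename, lineno) returns None or an iterable of mappings with "filename", "name", and "line" keys, which are appended as additional frames right after the frame that pointed into the generated file. A hedged Python rendering of that append step (the function and argument names here are illustrative, not the C++ API):

def symbolize_frame(frame, stack_frames_for_code):
    # frame: {"filename": ..., "name": ..., "line": ...} for one Python frame
    frames = [frame]
    extra = stack_frames_for_code(frame["filename"], frame["line"])
    if extra is not None:
        # original-model frames recovered from the inductor linemap
        frames.extend(
            {"filename": h["filename"], "name": h["name"], "line": h["line"]}
            for h in extra
        )
    return frames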