[memory debugging] Extract frame information from inductor (#95753)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95753
Approved by: https://github.com/Chillee
This commit is contained in:
Zachary DeVito
2023-03-15 15:14:10 -07:00
committed by PyTorch MergeBot
parent e74f70d212
commit 3162f71787
7 changed files with 127 additions and 12 deletions

View File

@ -7065,6 +7065,28 @@ if HAS_CUDA and not TEST_WITH_ASAN:
self.assertTrue(torch.allclose(module(input), traced(input)))
def test_memory_history_inductor(self):
    """Check that a CUDA memory snapshot mentions a function traced by torch.compile.

    A helper named ``called_inside_compile`` is invoked from a compiled
    function while memory history recording is switched on; afterwards the
    stringified snapshot must contain that helper's name.
    """

    def called_inside_compile(x, w, b):
        return torch.sigmoid(x @ w + b)

    @torch.compile
    def fn(x, w, b):
        x = called_inside_compile(x, w, b)
        return called_inside_compile(x, w, b)

    weight = torch.rand(3, 3, device="cuda")
    bias = torch.rand(3, device="cuda")
    inp = torch.rand(3, device="cuda")
    try:
        torch.cuda.memory.empty_cache()
        torch.cuda.memory._record_memory_history(True)
        # The result stays bound until the end of the test, so it is still
        # referenced when the snapshot below is taken.
        result = fn(inp, weight, bias)
    finally:
        # Always switch recording back off, even if compilation fails.
        torch.cuda.memory._record_memory_history(False)
    snapshot_text = str(torch.cuda.memory._snapshot())
    self.assertTrue("called_inside_compile" in snapshot_text)
copy_tests(CommonTemplate, CudaTests, "cuda")
class CudaReproTests(TestCase):

View File

@ -15,6 +15,7 @@ import sys
import sysconfig
import tempfile
import types
from bisect import bisect_right
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
from ctypes import cdll
from functools import partial
@ -635,10 +636,11 @@ class CppCodeCache:
class PyCodeCache:
cache = dict()
linemaps = dict()
clear = staticmethod(cache.clear)
@classmethod
def load(cls, source_code, extra=""):
def load(cls, source_code, extra="", linemap=()):
key, path = write(source_code, "py", extra)
if key not in cls.cache:
with open(path) as f:
@ -654,8 +656,36 @@ class PyCodeCache:
exec(code, mod.__dict__, mod.__dict__)
# another thread might set this first
cls.cache.setdefault(key, mod)
cls.linemaps[path] = linemap
return cls.cache[key]
@classmethod
@functools.lru_cache(None)
def stack_frames_for_code(cls, path, lineno):
    """Return the user-code frames recorded for a line of a loaded module.

    Args:
        path: file path used as the key into ``cls.linemaps``.
        lineno: line number inside that file.

    Returns:
        ``None`` when ``path`` has no line map, when ``lineno`` precedes the
        first mapped line, or when the matching entry is falsy or carries no
        stack trace. Otherwise a list of
        ``{"filename": str, "line": int, "name": str}`` dicts, in the reverse
        of the order in which the frames appear in the trace text.

    NOTE: results (including ``None``) are memoized per
    ``(cls, path, lineno)`` by ``lru_cache``.
    """
    if path not in cls.linemaps:
        return None
    # [(starting_line, <fx node>), ...]
    linemap = cls.linemaps[path]
    # bisect's ``key=`` argument only exists on Python >= 3.10, so search a
    # plain list of starting lines instead of the list of pairs.
    starting_lines = [start for start, _ in linemap]
    p = bisect_right(starting_lines, lineno)
    if p == 0:
        return None
    _, entry = linemap[p - 1]
    if not entry:
        return None
    stack_trace = entry.stack_trace
    if stack_trace is None:
        # Nothing to parse; report "no extra frames" instead of raising
        # from re.findall.
        return None
    # ideally fx stores stack traces as data rather than a string
    # but this is not along a performance critical path
    regex = r'File "(.+)", line (\d+), in (.+)\n'
    matches = re.findall(regex, stack_trace)
    return [
        {"filename": filename, "line": int(line), "name": name}
        for filename, line, name in reversed(matches)
    ]
class TritonCodeCache:
@staticmethod

View File

@ -17,6 +17,7 @@ from ..utils import (
cache_on_self,
get_benchmark_name,
has_triton,
LineContext,
sympy_dot,
sympy_product,
sympy_symbol,
@ -562,7 +563,7 @@ class WrapperCodeGen(CodeGen):
self.add_benchmark_harness(result)
return result.getvalue()
return result.getvaluewithlinemap()
def codegen_inputs(self, code: IndentedBuffer, graph_inputs: Dict[str, ir.Buffer]):
"""Assign all symbolic shapes to locals"""
@ -736,6 +737,9 @@ class WrapperCodeGen(CodeGen):
def writeline(self, line):
    """Append ``line`` to the end of ``self.lines``."""
    pending = self.lines
    pending.append(line)
def enter_context(self, ctx):
    """Append a ``LineContext`` marker wrapping ``ctx`` to ``self.lines``."""
    marker = LineContext(ctx)
    self.lines.append(marker)
class CppWrapperCodeGen(WrapperCodeGen):
"""

View File

@ -619,11 +619,11 @@ class GraphLowering(torch.fx.Interpreter):
def compile_to_module(self):
from .codecache import PyCodeCache
code = self.codegen()
code, linemap = self.codegen()
if config.debug:
print(code)
mod = PyCodeCache.load(code)
mod = PyCodeCache.load(code, linemap=linemap)
for name, value in self.constants.items():
setattr(mod, name, value)

View File

@ -674,6 +674,10 @@ class Scheduler:
self.buffer_names_to_free = set()
self.buffer_names_no_longer_needed = set()
# fx graph node to the position it appears in the graph
# for debug attribution
self.origin_to_index = {}
def debug_draw_graph(self):
"""Generate an image of the graph for debugging"""
if os.environ.get("INDUCTOR_WRITE_SCHEDULER_GRAPH", None) == "1":
@ -1164,9 +1168,21 @@ class Scheduler:
self.backends[device] = self.create_backend(device)
return self.backends[device]
def enter_context(self, node):
    """Tell the wrapper code which origin node the upcoming code belongs to.

    Collects the origins of every node returned by ``node.get_nodes()``,
    picks the one that appears latest in its graph, and passes it to
    ``V.graph.wrapper_code.enter_context``. Does nothing when there are no
    origins.
    """

    def get_order(origin):
        # Index a whole graph the first time one of its nodes is seen.
        if origin not in self.origin_to_index:
            self.origin_to_index.update(
                {fx_node: i for i, fx_node in enumerate(origin.graph.nodes)}
            )
        return self.origin_to_index[origin]

    origins = [
        (get_order(origin), origin)
        for sched_node in node.get_nodes()
        for origin in sched_node.node.origins
    ]
    if origins:
        # Compare on the position only: comparing whole tuples would fall
        # back to comparing the node objects whenever two positions tie.
        _, last = max(origins, key=lambda pair: pair[0])
        V.graph.wrapper_code.enter_context(last)
@dynamo_timed
def codegen(self):
for node in self.nodes:
self.enter_context(node)
self.buffer_names_no_longer_needed.update(node.last_usage)
if not isinstance(node, NopKernelSchedulerNode):

View File

@ -12,7 +12,7 @@ import tempfile
import textwrap
import time
from io import StringIO
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, NamedTuple, Optional, Union
from unittest import mock
import sympy
@ -420,6 +420,10 @@ def get_dtype_size(dtype):
return torch.empty((), dtype=dtype).element_size()
class LineContext(NamedTuple):
    """Single-field tuple wrapping a context object.

    Instances are stored in line lists alongside ordinary lines so that they
    can be told apart with ``isinstance``.
    """

    # The wrapped object; its type is not constrained here.
    context: Any
class IndentedBuffer:
tabwidth = 4
@ -427,19 +431,27 @@ class IndentedBuffer:
self._lines = []
self._indent = initial_indent
def getvaluewithlinemap(self):
    """Render the buffered lines and map output line numbers to contexts.

    Returns:
        A ``(text, linemap)`` pair. ``text`` is every string line followed
        by a newline. ``linemap`` is a list of ``(line_number, context)``
        tuples, one per ``LineContext`` entry, where ``line_number`` is the
        1-based line of ``text`` at which the next emitted line starts.
    """
    buf = StringIO()
    p = 1  # 1-based line number of the next line to be written
    linemap = []
    for line in self._lines:
        if isinstance(line, DeferredLineBase):
            # Deferred lines are resolved at render time and may vanish.
            line = line()
            if line is None:
                continue
        elif isinstance(line, LineContext):
            # Context markers produce no text, only a line map entry.
            linemap.append((p, line.context))
            continue
        assert isinstance(line, str)
        buf.write(line)
        buf.write("\n")
        # A single entry may itself span several physical lines.
        p += 1 + line.count("\n")
    return buf.getvalue(), linemap
def getvalue(self):
    """Return only the rendered text, discarding the line map."""
    text, _linemap = self.getvaluewithlinemap()
    return text
def getrawvalue(self):
buf = StringIO()
@ -448,6 +460,8 @@ class IndentedBuffer:
line = line()
if line is None:
continue
elif isinstance(line, LineContext):
continue
assert isinstance(line, str)
# backslash implies line continuation
if line.endswith("\\"):
@ -467,7 +481,9 @@ class IndentedBuffer:
return " " * (self._indent * self.tabwidth)
def writeline(self, line):
if isinstance(line, DeferredLineBase):
if isinstance(line, LineContext):
self._lines.append(line)
elif isinstance(line, DeferredLineBase):
self._lines.append(line.with_prefix(self.prefix()))
elif line.strip():
self._lines.append(f"{self.prefix()}{line}")
@ -493,12 +509,15 @@ class IndentedBuffer:
if isinstance(other_code, IndentedBuffer):
dedent = float("inf")
for line in other_code._lines:
if line:
if not isinstance(line, LineContext) and line:
dedent = min(dedent, len(line) - len(line.lstrip()))
if math.isinf(dedent):
dedent = 0
for line in other_code._lines:
IndentedBuffer.writeline(self, line[dedent:])
if isinstance(line, LineContext):
self._lines.append(line)
else:
IndentedBuffer.writeline(self, line[dedent:])
else:
other_code = textwrap.dedent(other_code)
if strip:

View File

@ -56,6 +56,16 @@ struct PythonTraceback : public CapturedTraceback::Python {
py::str name_s = "name";
py::str filename_s = "filename";
auto torch = py::module::import("torch");
py::object stack_frames_for_code;
if (py::hasattr(torch, "_inductor")) {
py::object inductor = torch.attr("_inductor");
if (py::hasattr(inductor, "codecache")) {
stack_frames_for_code = inductor.attr("codecache")
.attr("PyCodeCache")
.attr("stack_frames_for_code");
}
}
for (const auto& f : to_symbolize) {
auto f_code = (PyCodeObject*)f.code;
py::handle filename = f_code->co_filename;
@ -67,6 +77,20 @@ struct PythonTraceback : public CapturedTraceback::Python {
py::cast<std::string>(filename),
py::cast<std::string>(funcname),
(uint64_t)lineno});
// find all the additional frames associated with inductor generated
// code
if (stack_frames_for_code.ptr()) {
py::object extra = stack_frames_for_code(filename, lineno);
if (!extra.is_none()) {
for (py::handle h : extra) {
result.tracebacks.back().push_back(result.all_frames.size());
result.all_frames.emplace_back(unwind::Frame{
py::cast<std::string>(h[filename_s]),
py::cast<std::string>(h[name_s]),
py::cast<uint64_t>(h[line_s])});
}
}
}
}
}
};