[memory debugging] Extract frame information from inductor (#95753)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95753
Approved by: https://github.com/Chillee
This commit is contained in:
Zachary DeVito
2023-03-15 15:14:10 -07:00
committed by PyTorch MergeBot
parent e74f70d212
commit 3162f71787
7 changed files with 127 additions and 12 deletions

View File

@ -7065,6 +7065,28 @@ if HAS_CUDA and not TEST_WITH_ASAN:
self.assertTrue(torch.allclose(module(input), traced(input))) self.assertTrue(torch.allclose(module(input), traced(input)))
def test_memory_history_inductor(self):
    """Check that CUDA memory history names Python frames traced by inductor.

    A helper is invoked twice from a ``torch.compile``-d function while
    memory history recording is on; the stringified snapshot must then
    mention the helper's name.
    """

    def called_inside_compile(x, w, b):
        pre_activation = x @ w + b
        return torch.sigmoid(pre_activation)

    @torch.compile
    def fn(x, w, b):
        x = called_inside_compile(x, w, b)
        return called_inside_compile(x, w, b)

    w = torch.rand(3, 3, device="cuda")
    b = torch.rand(3, device="cuda")
    x = torch.rand(3, device="cuda")
    try:
        torch.cuda.memory.empty_cache()
        torch.cuda.memory._record_memory_history(True)
        # NOTE(review): presumably binding the result keeps its allocation
        # live until the snapshot below -- confirm before removing the name.
        result = fn(x, w, b)
    finally:
        # always switch recording back off, even if compilation fails
        torch.cuda.memory._record_memory_history(False)
    rendered = str(torch.cuda.memory._snapshot())
    self.assertTrue("called_inside_compile" in rendered)
copy_tests(CommonTemplate, CudaTests, "cuda") copy_tests(CommonTemplate, CudaTests, "cuda")
class CudaReproTests(TestCase): class CudaReproTests(TestCase):

View File

@ -15,6 +15,7 @@ import sys
import sysconfig import sysconfig
import tempfile import tempfile
import types import types
from bisect import bisect_right
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
from ctypes import cdll from ctypes import cdll
from functools import partial from functools import partial
@ -635,10 +636,11 @@ class CppCodeCache:
class PyCodeCache: class PyCodeCache:
cache = dict() cache = dict()
linemaps = dict()
clear = staticmethod(cache.clear) clear = staticmethod(cache.clear)
@classmethod @classmethod
def load(cls, source_code, extra=""): def load(cls, source_code, extra="", linemap=()):
key, path = write(source_code, "py", extra) key, path = write(source_code, "py", extra)
if key not in cls.cache: if key not in cls.cache:
with open(path) as f: with open(path) as f:
@ -654,8 +656,36 @@ class PyCodeCache:
exec(code, mod.__dict__, mod.__dict__) exec(code, mod.__dict__, mod.__dict__)
# another thread might set this first # another thread might set this first
cls.cache.setdefault(key, mod) cls.cache.setdefault(key, mod)
cls.linemaps[path] = linemap
return cls.cache[key] return cls.cache[key]
@classmethod
@functools.lru_cache(None)
def stack_frames_for_code(cls, path, lineno):
    """Translate a line of generated code into the frames recorded for it.

    Args:
        path: File path used as the key into ``cls.linemaps``.
        lineno: Line number inside that file.

    Returns:
        A list of ``{"filename", "line", "name"}`` dicts parsed from the
        matching entry's ``stack_trace`` text, in the reverse of the order
        in which the frames appear in that text.  ``None`` is returned when
        the path has no linemap, when ``lineno`` precedes the first mapped
        line, or when the matching entry is empty or carries no stack trace.

    Results are memoized by ``lru_cache``; the same list object is handed
    back on repeated calls, so callers should treat it as read-only.
    """
    # [(starting_line, <fx node>), ...]
    linemap = cls.linemaps.get(path)
    if not linemap:
        return None

    # ``bisect_right(..., key=...)`` only exists on Python 3.10+, so search a
    # plain list of starting lines instead of passing a key function.
    starting_lines = [start for start, _ in linemap]
    p = bisect_right(starting_lines, lineno)
    if p == 0:
        return None
    _, entry = linemap[p - 1]
    if not entry:
        return None

    stack_trace = entry.stack_trace
    if stack_trace is None:
        # nothing was recorded for this entry; re.findall would raise on None
        return None

    # ideally fx stores stack traces as data rather than a string
    # but this is not along a performance critical path
    regex = r'File "(.+)", line (\d+), in (.+)\n'
    matches = re.findall(regex, stack_trace)
    return [
        {"filename": filename, "line": int(line), "name": name}
        for filename, line, name in reversed(matches)
    ]
class TritonCodeCache: class TritonCodeCache:
@staticmethod @staticmethod

View File

@ -17,6 +17,7 @@ from ..utils import (
cache_on_self, cache_on_self,
get_benchmark_name, get_benchmark_name,
has_triton, has_triton,
LineContext,
sympy_dot, sympy_dot,
sympy_product, sympy_product,
sympy_symbol, sympy_symbol,
@ -562,7 +563,7 @@ class WrapperCodeGen(CodeGen):
self.add_benchmark_harness(result) self.add_benchmark_harness(result)
return result.getvalue() return result.getvaluewithlinemap()
def codegen_inputs(self, code: IndentedBuffer, graph_inputs: Dict[str, ir.Buffer]): def codegen_inputs(self, code: IndentedBuffer, graph_inputs: Dict[str, ir.Buffer]):
"""Assign all symbolic shapes to locals""" """Assign all symbolic shapes to locals"""
@ -736,6 +737,9 @@ class WrapperCodeGen(CodeGen):
def writeline(self, line): def writeline(self, line):
self.lines.append(line) self.lines.append(line)
def enter_context(self, ctx):
    """Append a ``LineContext`` marker wrapping ``ctx`` to the pending lines."""
    marker = LineContext(ctx)
    self.lines.append(marker)
class CppWrapperCodeGen(WrapperCodeGen): class CppWrapperCodeGen(WrapperCodeGen):
""" """

View File

@ -619,11 +619,11 @@ class GraphLowering(torch.fx.Interpreter):
def compile_to_module(self): def compile_to_module(self):
from .codecache import PyCodeCache from .codecache import PyCodeCache
code = self.codegen() code, linemap = self.codegen()
if config.debug: if config.debug:
print(code) print(code)
mod = PyCodeCache.load(code) mod = PyCodeCache.load(code, linemap=linemap)
for name, value in self.constants.items(): for name, value in self.constants.items():
setattr(mod, name, value) setattr(mod, name, value)

View File

@ -674,6 +674,10 @@ class Scheduler:
self.buffer_names_to_free = set() self.buffer_names_to_free = set()
self.buffer_names_no_longer_needed = set() self.buffer_names_no_longer_needed = set()
# fx graph node to the position it appears in the graph
# for debug attribution
self.origin_to_index = {}
def debug_draw_graph(self): def debug_draw_graph(self):
"""Generate an image of the graph for debugging""" """Generate an image of the graph for debugging"""
if os.environ.get("INDUCTOR_WRITE_SCHEDULER_GRAPH", None) == "1": if os.environ.get("INDUCTOR_WRITE_SCHEDULER_GRAPH", None) == "1":
@ -1164,9 +1168,21 @@ class Scheduler:
self.backends[device] = self.create_backend(device) self.backends[device] = self.create_backend(device)
return self.backends[device] return self.backends[device]
def enter_context(self, node):
    """Tell the wrapper codegen which origin the upcoming code belongs to.

    Collects every origin of every node returned by ``node.get_nodes()``,
    ranks each by its position in its own graph, and forwards the
    highest-ranked one to ``V.graph.wrapper_code.enter_context``.  Nothing
    is forwarded when there are no origins.
    """

    def get_order(origin):
        # Positions are filled in lazily, one whole graph at a time, the
        # first time any node of that graph is looked up.
        if origin not in self.origin_to_index:
            self.origin_to_index.update(
                {graph_node: i for i, graph_node in enumerate(origin.graph.nodes)}
            )
        return self.origin_to_index[origin]

    origins = [(get_order(e), e) for n in node.get_nodes() for e in n.node.origins]
    if origins:
        # Rank on the index alone: comparing whole tuples would fall through
        # to comparing the origin objects themselves whenever two distinct
        # origins share an index.
        _, last = max(origins, key=lambda pair: pair[0])
        V.graph.wrapper_code.enter_context(last)
@dynamo_timed @dynamo_timed
def codegen(self): def codegen(self):
for node in self.nodes: for node in self.nodes:
self.enter_context(node)
self.buffer_names_no_longer_needed.update(node.last_usage) self.buffer_names_no_longer_needed.update(node.last_usage)
if not isinstance(node, NopKernelSchedulerNode): if not isinstance(node, NopKernelSchedulerNode):

View File

@ -12,7 +12,7 @@ import tempfile
import textwrap import textwrap
import time import time
from io import StringIO from io import StringIO
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, NamedTuple, Optional, Union
from unittest import mock from unittest import mock
import sympy import sympy
@ -420,6 +420,10 @@ def get_dtype_size(dtype):
return torch.empty((), dtype=dtype).element_size() return torch.empty((), dtype=dtype).element_size()
class LineContext(NamedTuple):
    """Marker stored among an ``IndentedBuffer``'s lines instead of text.

    ``IndentedBuffer.getvaluewithlinemap`` emits no output for it; it
    records ``context`` against the line number of the next emitted line.
    """

    # opaque payload handed back in the linemap
    context: Any
class IndentedBuffer: class IndentedBuffer:
tabwidth = 4 tabwidth = 4
@ -427,19 +431,27 @@ class IndentedBuffer:
self._lines = [] self._lines = []
self._indent = initial_indent self._indent = initial_indent
def getvaluewithlinemap(self):
    """Render the buffer and return ``(text, linemap)``.

    ``text`` is every string line followed by a newline.  ``linemap`` is a
    list of ``(line_number, context)`` pairs, one per ``LineContext``
    marker, where ``line_number`` (counted from 1) is the line of ``text``
    that the next emitted string will start on.  Deferred lines are
    resolved by calling them and are dropped when they resolve to ``None``.
    """
    out = StringIO()
    next_lineno = 1
    markers = []
    for item in self._lines:
        if isinstance(item, DeferredLineBase):
            item = item()
            if item is None:
                continue
        elif isinstance(item, LineContext):
            markers.append((next_lineno, item.context))
            continue
        assert isinstance(item, str)
        out.write(item)
        out.write("\n")
        # an entry may itself span several physical lines
        next_lineno += 1 + item.count("\n")
    return out.getvalue(), markers
def getvalue(self):
    """Return only the rendered text, discarding the linemap."""
    return self.getvaluewithlinemap()[0]
def getrawvalue(self): def getrawvalue(self):
buf = StringIO() buf = StringIO()
@ -448,6 +460,8 @@ class IndentedBuffer:
line = line() line = line()
if line is None: if line is None:
continue continue
elif isinstance(line, LineContext):
continue
assert isinstance(line, str) assert isinstance(line, str)
# backslash implies line continuation # backslash implies line continuation
if line.endswith("\\"): if line.endswith("\\"):
@ -467,7 +481,9 @@ class IndentedBuffer:
return " " * (self._indent * self.tabwidth) return " " * (self._indent * self.tabwidth)
def writeline(self, line): def writeline(self, line):
if isinstance(line, DeferredLineBase): if isinstance(line, LineContext):
self._lines.append(line)
elif isinstance(line, DeferredLineBase):
self._lines.append(line.with_prefix(self.prefix())) self._lines.append(line.with_prefix(self.prefix()))
elif line.strip(): elif line.strip():
self._lines.append(f"{self.prefix()}{line}") self._lines.append(f"{self.prefix()}{line}")
@ -493,12 +509,15 @@ class IndentedBuffer:
if isinstance(other_code, IndentedBuffer): if isinstance(other_code, IndentedBuffer):
dedent = float("inf") dedent = float("inf")
for line in other_code._lines: for line in other_code._lines:
if line: if not isinstance(line, LineContext) and line:
dedent = min(dedent, len(line) - len(line.lstrip())) dedent = min(dedent, len(line) - len(line.lstrip()))
if math.isinf(dedent): if math.isinf(dedent):
dedent = 0 dedent = 0
for line in other_code._lines: for line in other_code._lines:
IndentedBuffer.writeline(self, line[dedent:]) if isinstance(line, LineContext):
self._lines.append(line)
else:
IndentedBuffer.writeline(self, line[dedent:])
else: else:
other_code = textwrap.dedent(other_code) other_code = textwrap.dedent(other_code)
if strip: if strip:

View File

@ -56,6 +56,16 @@ struct PythonTraceback : public CapturedTraceback::Python {
py::str name_s = "name"; py::str name_s = "name";
py::str filename_s = "filename"; py::str filename_s = "filename";
auto torch = py::module::import("torch");
py::object stack_frames_for_code;
if (py::hasattr(torch, "_inductor")) {
py::object inductor = torch.attr("_inductor");
if (py::hasattr(inductor, "codecache")) {
stack_frames_for_code = inductor.attr("codecache")
.attr("PyCodeCache")
.attr("stack_frames_for_code");
}
}
for (const auto& f : to_symbolize) { for (const auto& f : to_symbolize) {
auto f_code = (PyCodeObject*)f.code; auto f_code = (PyCodeObject*)f.code;
py::handle filename = f_code->co_filename; py::handle filename = f_code->co_filename;
@ -67,6 +77,20 @@ struct PythonTraceback : public CapturedTraceback::Python {
py::cast<std::string>(filename), py::cast<std::string>(filename),
py::cast<std::string>(funcname), py::cast<std::string>(funcname),
(uint64_t)lineno}); (uint64_t)lineno});
// find all the additional frames associated with inductor generated
// code
if (stack_frames_for_code.ptr()) {
py::object extra = stack_frames_for_code(filename, lineno);
if (!extra.is_none()) {
for (py::handle h : extra) {
result.tracebacks.back().push_back(result.all_frames.size());
result.all_frames.emplace_back(unwind::Frame{
py::cast<std::string>(h[filename_s]),
py::cast<std::string>(h[name_s]),
py::cast<uint64_t>(h[line_s])});
}
}
}
} }
} }
}; };