Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[memory debugging] Extract frame information from inductor (#95753)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95753
Approved by: https://github.com/Chillee
Committed by: PyTorch MergeBot
Parent: e74f70d212
Commit: 3162f71787
@@ -7065,6 +7065,28 @@ if HAS_CUDA and not TEST_WITH_ASAN:
             self.assertTrue(torch.allclose(module(input), traced(input)))

+        def test_memory_history_inductor(self):
+            def called_inside_compile(x, w, b):
+                a = x @ w + b
+                return torch.sigmoid(a)
+
+            @torch.compile
+            def fn(x, w, b):
+                x = called_inside_compile(x, w, b)
+                return called_inside_compile(x, w, b)
+
+            w = torch.rand(3, 3, device="cuda")
+            b = torch.rand(3, device="cuda")
+            x = torch.rand(3, device="cuda")
+            try:
+                torch.cuda.memory.empty_cache()
+                torch.cuda.memory._record_memory_history(True)
+                r = fn(x, w, b)
+            finally:
+                torch.cuda.memory._record_memory_history(False)
+            snapshot = str(torch.cuda.memory._snapshot())
+            self.assertTrue("called_inside_compile" in snapshot)
+
     copy_tests(CommonTemplate, CudaTests, "cuda")

     class CudaReproTests(TestCase):
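The new test above exercises the end-to-end behavior: allocations made while running inductor-compiled code should show the original user frames (here `called_inside_compile`) in the recorded memory history. As a hedged, standalone sketch of the same flow outside the test harness (it assumes a CUDA build; `_record_memory_history` and `_snapshot` are private APIs, and the exact snapshot layout may differ between releases):

    import torch

    @torch.compile
    def layer(x, w, b):
        return torch.sigmoid(x @ w + b)

    w = torch.rand(3, 3, device="cuda")
    b = torch.rand(3, device="cuda")
    x = torch.rand(3, device="cuda")

    torch.cuda.memory._record_memory_history(True)
    try:
        layer(x, w, b)
    finally:
        torch.cuda.memory._record_memory_history(False)

    # With this change, frames recorded for allocations coming from
    # inductor-generated code should include the original fx-node stack
    # entries, so the user-level function name appears in the snapshot.
    snapshot = torch.cuda.memory._snapshot()
    print("layer" in str(snapshot))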
@@ -15,6 +15,7 @@ import sys
 import sysconfig
 import tempfile
 import types
+from bisect import bisect_right
 from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
 from ctypes import cdll
 from functools import partial
@@ -635,10 +636,11 @@ class CppCodeCache:

 class PyCodeCache:
     cache = dict()
+    linemaps = dict()
     clear = staticmethod(cache.clear)

     @classmethod
-    def load(cls, source_code, extra=""):
+    def load(cls, source_code, extra="", linemap=()):
         key, path = write(source_code, "py", extra)
         if key not in cls.cache:
             with open(path) as f:
@@ -654,8 +656,36 @@ class PyCodeCache:
                 exec(code, mod.__dict__, mod.__dict__)
                 # another thread might set this first
                 cls.cache.setdefault(key, mod)
+                cls.linemaps[path] = linemap

         return cls.cache[key]

+    @classmethod
+    @functools.lru_cache(None)
+    def stack_frames_for_code(cls, path, lineno):
+        if path not in cls.linemaps:
+            return None
+        # [(starting_line, <fx node>), ...]
+        linemap = cls.linemaps[path]
+        p = bisect_right(linemap, lineno, key=lambda x: x[0])
+        if p == 0:
+            return None
+        _, entry = linemap[p - 1]
+        if not entry:
+            return None
+
+        def parse_stack_trace(stack_trace):
+            # ideally fx stores stack traces as data rather than a string
+            # but this is not along a performance critical path
+            regex = r'File "(.+)", line (\d+), in (.+)\n'
+            matches = re.findall(regex, stack_trace)
+            return [
+                {"filename": f, "line": int(l), "name": n}
+                for f, l, n in reversed(matches)
+            ]
+
+        return parse_stack_trace(entry.stack_trace)
+
+
 class TritonCodeCache:
     @staticmethod
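To make the lookup in `stack_frames_for_code` concrete: the linemap is a list of `(starting_line, fx_node)` pairs sorted by starting line, and `bisect_right` with a `key=` argument (Python 3.10+) returns the insertion point for the queried line, so the entry just before it is the block that line falls into. A toy illustration with hypothetical string payloads standing in for fx nodes:

    from bisect import bisect_right

    # hypothetical map: generated lines 1-9 come from "node_a",
    # 10-24 from "node_b", 25 onward from "node_c"
    linemap = [(1, "node_a"), (10, "node_b"), (25, "node_c")]

    def entry_for_line(lineno):
        p = bisect_right(linemap, lineno, key=lambda x: x[0])
        return None if p == 0 else linemap[p - 1][1]

    assert entry_for_line(3) == "node_a"
    assert entry_for_line(10) == "node_b"
    assert entry_for_line(100) == "node_c"
    assert entry_for_line(0) is None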
@@ -17,6 +17,7 @@ from ..utils import (
     cache_on_self,
     get_benchmark_name,
     has_triton,
+    LineContext,
     sympy_dot,
     sympy_product,
     sympy_symbol,
@@ -562,7 +563,7 @@ class WrapperCodeGen(CodeGen):

         self.add_benchmark_harness(result)

-        return result.getvalue()
+        return result.getvaluewithlinemap()

     def codegen_inputs(self, code: IndentedBuffer, graph_inputs: Dict[str, ir.Buffer]):
         """Assign all symbolic shapes to locals"""
@@ -736,6 +737,9 @@ class WrapperCodeGen(CodeGen):
     def writeline(self, line):
         self.lines.append(line)

+    def enter_context(self, ctx):
+        self.lines.append(LineContext(ctx))
+

 class CppWrapperCodeGen(WrapperCodeGen):
     """
@@ -619,11 +619,11 @@ class GraphLowering(torch.fx.Interpreter):
     def compile_to_module(self):
         from .codecache import PyCodeCache

-        code = self.codegen()
+        code, linemap = self.codegen()
         if config.debug:
             print(code)

-        mod = PyCodeCache.load(code)
+        mod = PyCodeCache.load(code, linemap=linemap)
         for name, value in self.constants.items():
             setattr(mod, name, value)

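Putting the pieces so far together: `GraphLowering.codegen()` now returns the generated source plus its line map, `PyCodeCache.load()` records that map under the generated file's path, and `stack_frames_for_code(path, lineno)` can later translate a position in that file back into user frames. A minimal hand-driven sketch under those assumptions (the one-line source is a stand-in; with an empty linemap the lookup simply yields `None`):

    from torch._inductor.codecache import PyCodeCache

    # load() writes the source into the inductor cache dir and stores the linemap
    mod = PyCodeCache.load("answer = 42\n", linemap=[])
    assert mod.answer == 42

    # no linemap entries precede line 1, so no extra frames can be attributed
    assert PyCodeCache.stack_frames_for_code(mod.__file__, 1) is None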
@@ -674,6 +674,10 @@ class Scheduler:
         self.buffer_names_to_free = set()
         self.buffer_names_no_longer_needed = set()

+        # fx graph node to the position it appears in the graph
+        # for debug attribution
+        self.origin_to_index = {}
+
     def debug_draw_graph(self):
         """Generate an image of the graph for debugging"""
         if os.environ.get("INDUCTOR_WRITE_SCHEDULER_GRAPH", None) == "1":
@@ -1164,9 +1168,21 @@ class Scheduler:
             self.backends[device] = self.create_backend(device)
         return self.backends[device]

+    def enter_context(self, node):
+        def get_order(n):
+            if n not in self.origin_to_index:
+                self.origin_to_index.update({n: i for i, n in enumerate(n.graph.nodes)})
+            return self.origin_to_index[n]
+
+        origins = [(get_order(e), e) for n in node.get_nodes() for e in n.node.origins]
+        if origins:
+            _, last = max(origins)
+            V.graph.wrapper_code.enter_context(last)
+
     @dynamo_timed
     def codegen(self):
         for node in self.nodes:
+            self.enter_context(node)
             self.buffer_names_no_longer_needed.update(node.last_usage)

             if not isinstance(node, NopKernelSchedulerNode):
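The scheduler hook above attributes each scheduled kernel to the last of its originating fx nodes in graph order, then hands that node to the wrapper code generator as the current context. The selection itself is just a `max` over `(position, node)` pairs, sketched here on plain tuples with made-up names:

    # positions are the nodes' indices in the fx graph; max() keeps the latest
    origins = [(3, "mm"), (1, "add"), (7, "sigmoid")]
    if origins:
        _, last = max(origins)
    assert last == "sigmoid"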
@@ -12,7 +12,7 @@ import tempfile
 import textwrap
 import time
 from io import StringIO
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, NamedTuple, Optional, Union
 from unittest import mock

 import sympy
@@ -420,6 +420,10 @@ def get_dtype_size(dtype):
     return torch.empty((), dtype=dtype).element_size()


+class LineContext(NamedTuple):
+    context: Any
+
+
 class IndentedBuffer:
     tabwidth = 4

@@ -427,19 +431,27 @@ class IndentedBuffer:
         self._lines = []
         self._indent = initial_indent

-    def getvalue(
-        self,
-    ):
+    def getvaluewithlinemap(self):
         buf = StringIO()
+        p = 1
+        linemap = []
         for line in self._lines:
             if isinstance(line, DeferredLineBase):
                 line = line()
                 if line is None:
                     continue
+            elif isinstance(line, LineContext):
+                linemap.append((p, line.context))
+                continue
             assert isinstance(line, str)
             buf.write(line)
             buf.write("\n")
-        return buf.getvalue()
+            p += 1 + line.count("\n")
+        return buf.getvalue(), linemap
+
+    def getvalue(self):
+        v, _ = self.getvaluewithlinemap()
+        return v

     def getrawvalue(self):
         buf = StringIO()
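The effect of `getvaluewithlinemap` is that `LineContext` markers never reach the emitted source, but their payload is recorded next to the output line number where the following code begins. A stripped-down sketch of that bookkeeping (not the real `IndentedBuffer`):

    from io import StringIO
    from typing import Any, NamedTuple

    class LineContext(NamedTuple):
        context: Any

    def value_with_linemap(lines):
        buf, p, linemap = StringIO(), 1, []
        for line in lines:
            if isinstance(line, LineContext):
                # remember which output line this context starts at
                linemap.append((p, line.context))
                continue
            buf.write(line + "\n")
            p += 1 + line.count("\n")
        return buf.getvalue(), linemap

    text, linemap = value_with_linemap(
        [LineContext("node_a"), "x = foo()", LineContext("node_b"), "y = bar(x)"]
    )
    assert linemap == [(1, "node_a"), (2, "node_b")]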
@@ -448,6 +460,8 @@ class IndentedBuffer:
                 line = line()
                 if line is None:
                     continue
+            elif isinstance(line, LineContext):
+                continue
             assert isinstance(line, str)
             # backslash implies line continuation
             if line.endswith("\\"):
@@ -467,7 +481,9 @@ class IndentedBuffer:
         return " " * (self._indent * self.tabwidth)

     def writeline(self, line):
-        if isinstance(line, DeferredLineBase):
+        if isinstance(line, LineContext):
+            self._lines.append(line)
+        elif isinstance(line, DeferredLineBase):
             self._lines.append(line.with_prefix(self.prefix()))
         elif line.strip():
             self._lines.append(f"{self.prefix()}{line}")
@@ -493,12 +509,15 @@ class IndentedBuffer:
         if isinstance(other_code, IndentedBuffer):
             dedent = float("inf")
             for line in other_code._lines:
-                if line:
+                if not isinstance(line, LineContext) and line:
                     dedent = min(dedent, len(line) - len(line.lstrip()))
             if math.isinf(dedent):
                 dedent = 0
             for line in other_code._lines:
-                IndentedBuffer.writeline(self, line[dedent:])
+                if isinstance(line, LineContext):
+                    self._lines.append(line)
+                else:
+                    IndentedBuffer.writeline(self, line[dedent:])
         else:
             other_code = textwrap.dedent(other_code)
             if strip:
@@ -56,6 +56,16 @@ struct PythonTraceback : public CapturedTraceback::Python {
     py::str name_s = "name";
     py::str filename_s = "filename";

+    auto torch = py::module::import("torch");
+    py::object stack_frames_for_code;
+    if (py::hasattr(torch, "_inductor")) {
+      py::object inductor = torch.attr("_inductor");
+      if (py::hasattr(inductor, "codecache")) {
+        stack_frames_for_code = inductor.attr("codecache")
+                                    .attr("PyCodeCache")
+                                    .attr("stack_frames_for_code");
+      }
+    }
     for (const auto& f : to_symbolize) {
       auto f_code = (PyCodeObject*)f.code;
       py::handle filename = f_code->co_filename;
@@ -67,6 +77,20 @@ struct PythonTraceback : public CapturedTraceback::Python {
           py::cast<std::string>(filename),
           py::cast<std::string>(funcname),
           (uint64_t)lineno});
+      // find all the additional frames associated with inductor generated
+      // code
+      if (stack_frames_for_code.ptr()) {
+        py::object extra = stack_frames_for_code(filename, lineno);
+        if (!extra.is_none()) {
+          for (py::handle h : extra) {
+            result.tracebacks.back().push_back(result.all_frames.size());
+            result.all_frames.emplace_back(unwind::Frame{
+                py::cast<std::string>(h[filename_s]),
+                py::cast<std::string>(h[name_s]),
+                py::cast<uint64_t>(h[line_s])});
+          }
+        }
+      }
     }
   }
 };
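On the C++ side, the traceback symbolizer now asks `PyCodeCache.stack_frames_for_code(filename, lineno)` for every captured Python frame and appends whatever extra frames it returns, so frames pointing into an inductor-generated file gain the original user frames right after them. A rough Python-level sketch of that post-processing step (the frame-dict keys match what `parse_stack_trace` produces above; `lookup` is a stand-in for the bound `stack_frames_for_code`):

    def expand_frames(frames, lookup):
        # frames: [{"filename": ..., "line": ..., "name": ...}, ...]
        out = []
        for f in frames:
            out.append(f)
            extra = lookup(f["filename"], f["line"])
            if extra:
                # frames from inductor-generated files gain the fx-node frames
                out.extend(extra)
        return out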