mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
We need to add Inductor debug-symbol support to help debug crashes. When debug-symbol generation is turned on: on Windows, the build should create a [module_name].pdb file, which helps when debugging with WinDBG; on Linux, it should create debug sections in the binary file. I also added a unit test for it. It works well for Inductor debugging on Windows. <img width="1648" height="833" alt="image" src="https://github.com/user-attachments/assets/5282a7de-cef3-4a38-9cd4-a0e63482c8b6" /> Pull Request resolved: https://github.com/pytorch/pytorch/pull/159938 Approved by: https://github.com/jansel, https://github.com/angelayi
190 lines
5.8 KiB
Python
190 lines
5.8 KiB
Python
# Owner(s): ["module: inductor"]
|
||
import os
|
||
import shlex
|
||
import subprocess
|
||
import sys
|
||
from unittest import mock
|
||
|
||
import torch
|
||
from torch import _dynamo as dynamo, _inductor as inductor
|
||
from torch._inductor.codecache import write
|
||
from torch._inductor.cpp_builder import CppBuilder, CppOptions
|
||
from torch._inductor.test_case import run_tests, TestCase
|
||
from torch._inductor.utils import gen_gm_and_inputs
|
||
from torch.fx import symbolic_trace
|
||
from torch.fx.experimental.proxy_tensor import make_fx
|
||
from torch.testing._internal.inductor_utils import HAS_CPU
|
||
|
||
|
||
# Host-platform flags, evaluated once at import time from ``sys.platform``.
# They select which debug-symbol check the build test performs.
_IS_WINDOWS = sys.platform == "win32"
_IS_MACOS = sys.platform.startswith("darwin")
|
||
|
||
|
||
def safe_command_output(cmd, timeout=30):
    """Run ``cmd`` and return its output without raising on failure.

    Args:
        cmd: Command to execute. A ``str`` is run through the shell; any
            other value (e.g. a list of arguments) is executed directly.
        timeout: Maximum number of seconds to wait for the command.

    Returns:
        The stripped, combined stdout/stderr text when the command exits
        with status 0. Otherwise a short description of the failure:
        non-zero exit status, timeout, or failure to launch the command.
    """
    try:
        return subprocess.check_output(
            cmd,
            stderr=subprocess.STDOUT,  # fold stderr into the captured text
            text=True,
            timeout=timeout,
            shell=isinstance(cmd, str),
        ).strip()
    except subprocess.CalledProcessError as e:
        return f"run failed(error code {e.returncode}): {(e.output or '').strip()}"
    except subprocess.TimeoutExpired:
        return "run timeout"
    except OSError as e:
        # Raised when the command cannot be launched at all, e.g. the
        # executable in a list-form ``cmd`` does not exist. Report it the
        # same way as any other failure instead of propagating.
        return f"run failed({type(e).__name__}): {e}"
|
||
|
||
|
||
class MyModule(torch.nn.Module):
    """Small network used as the compilation target in these tests.

    Computes ``sigmoid(b(relu(a(x))))`` where ``a`` and ``b`` are both
    ``Linear(10, 10)`` layers.
    """

    def __init__(self) -> None:
        super().__init__()
        self.a = torch.nn.Linear(10, 10)
        self.b = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        # First layer followed by ReLU.
        hidden = self.a(x)
        hidden = self.relu(hidden)
        # Second layer followed by sigmoid.
        logits = self.b(hidden)
        return torch.sigmoid(logits)
|
||
|
||
|
||
class MyModule2(MyModule):
    """``MyModule`` variant with a dict-structured input and output."""

    def forward(self, x):  # takes a dict of list
        # x["key"] is unpacked into exactly two items: the value fed to the
        # parent network and a term added to its output.
        primary, residual = x["key"]
        activated = super().forward(primary)
        return {"result": activated + residual}
|
||
|
||
|
||
class MyModule3(MyModule):
    """``MyModule`` variant that wraps its result in a one-element tuple."""

    def forward(self, x):
        out = super().forward(x)
        return (out,)
|
||
|
||
|
||
class TestStandaloneInductor(TestCase):
    """
    These tests check that you can call TorchInductor directly without
    going through TorchDynamo.
    """

    def test_inductor_via_fx(self):
        # Graph from symbolic_trace; the module returns a tuple.
        mod = MyModule3().eval()
        inp = torch.randn(10)
        correct = mod(inp)
        mod_opt = inductor.compile(symbolic_trace(mod), [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)

    def test_inductor_via_fx_tensor_return(self):
        # Graph from symbolic_trace; the module returns a bare tensor.
        mod = MyModule().eval()
        inp = torch.randn(10)
        correct = mod(inp)
        mod_opt = inductor.compile(symbolic_trace(mod), [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)

    def test_inductor_via_fx_dict_input(self):
        # Graph from symbolic_trace; input is a dict holding a list of tensors.
        mod = MyModule2().eval()
        inp = {"key": [torch.randn(10), torch.randn(10)]}
        correct = mod(inp)
        mod_opt = inductor.compile(symbolic_trace(mod), [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)

    def test_inductor_via_make_fx(self):
        # Graph from make_fx instead of symbolic_trace.
        mod = MyModule().eval()
        inp = torch.randn(10)
        correct = mod(inp)
        mod_opt = inductor.compile(make_fx(mod)(inp), [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)

    def test_inductor_via_bare_module(self):
        mod = MyModule3().eval()
        inp = torch.randn(10)
        correct = mod(inp)
        # no FX graph at all (mod must return list/tuple in this case)
        mod_opt = inductor.compile(mod, [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)

    def test_inductor_via_export1(self):
        # Graph from dynamo.export with aten_graph=True.
        mod = MyModule3().eval()
        inp = torch.randn(10)
        correct = mod(inp)
        gm, _ = dynamo.export(mod, inp, aten_graph=True)
        mod_opt = inductor.compile(gm, [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)

    def test_inductor_via_export2(self):
        # Graph from dynamo.export (default options) with a dict input.
        mod = MyModule2().eval()
        inp = {"key": [torch.randn(10), torch.randn(10)]}
        correct = mod(inp)
        gm, _ = dynamo.export(mod, inp)
        mod_opt = inductor.compile(gm, [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)

    def test_inductor_via_op_with_multiple_outputs(self):
        # Compile a single aten op (native_layer_norm) wrapped by
        # gen_gm_and_inputs and compare against the uncompiled module.
        x1 = torch.randn((2, 512, 128))
        x2 = [128]
        x3 = torch.randn(128)
        x4 = torch.randn((128,))
        x5 = 1e-6
        mod, inp = gen_gm_and_inputs(
            torch.ops.aten.native_layer_norm.default, (x1, x2, x3, x4, x5), {}
        )
        mod_opt = inductor.compile(mod, inp)
        self.assertEqual(mod(*inp), mod_opt(*inp))

    @mock.patch.dict(os.environ, {"TORCHINDUCTOR_DEBUG_SYMBOL": "1"})
    def test_inductor_generate_debug_symbol(self):
        # Build a trivial C++ program with TORCHINDUCTOR_DEBUG_SYMBOL=1 set in
        # the environment, then check that debug symbols were produced.
        cpp_code = """
int main(){
return 0;
}
"""

        _, source_path = write(
            cpp_code,
            "cpp",
        )
        build_option = CppOptions()
        cpp_builder = CppBuilder(
            name="test_symbol",
            sources=source_path,
            output_dir=os.path.dirname(source_path),
            BuildOption=build_option,
        )
        cpp_builder.build()
        binary_path = cpp_builder.get_target_file_path()

        # When debug-symbol generation is turned on:
        #   * On Windows, the build should create a [module_name].pdb file,
        #     which helps when debugging with WinDBG.
        #   * On Linux, it should create debug sections in the binary file.

        def check_linux_debug_section(module_path: str):
            # Quote the path so that whitespace or shell metacharacters in it
            # survive shlex.split as a single argument.
            check_cmd = shlex.split(f"readelf -S {shlex.quote(module_path)}")
            output = safe_command_output(check_cmd)
            # assertIn prints the readelf output on failure.
            self.assertIn(".debug_info", output)

        def check_windows_pdb_exist(module_path: str):
            file_name_no_ext = os.path.splitext(module_path)[0]
            file_name_pdb = f"{file_name_no_ext}.pdb"
            self.assertTrue(
                os.path.exists(file_name_pdb),
                f"expected debug symbol file not found: {file_name_pdb}",
            )

        if _IS_WINDOWS:
            check_windows_pdb_exist(binary_path)
        elif _IS_MACOS:
            pass  # macOS: not sure whether this should work, so nothing is checked.
        else:
            check_linux_debug_section(binary_path)
|
||
|
||
|
||
# Run the suite only when executed as a script and HAS_CPU is truthy.
if __name__ == "__main__" and HAS_CPU:
    run_tests()
|