Files
pytorch/test/inductor/test_compile.py
Xu Han c71950907d [inductor] add _get_inductor_debug_symbol_cflags for debug symbol control. (#159938)
We need inductor debug-symbol support in order to debug crash cases. When debug-symbol generation is turned on:
On Windows, it should create a [module_name].pdb file, which helps when debugging with WinDbg.
On Linux, it should create debug sections in the binary file.

I also added a unit test (UT) for it.

It works well when debugging inductor on Windows.
<img width="1648" height="833" alt="image" src="https://github.com/user-attachments/assets/5282a7de-cef3-4a38-9cd4-a0e63482c8b6" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159938
Approved by: https://github.com/jansel, https://github.com/angelayi
2025-08-06 19:31:45 +00:00

190 lines
5.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Owner(s): ["module: inductor"]
import os
import shlex
import subprocess
import sys
from unittest import mock
import torch
from torch import _dynamo as dynamo, _inductor as inductor
from torch._inductor.codecache import write
from torch._inductor.cpp_builder import CppBuilder, CppOptions
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import gen_gm_and_inputs
from torch.fx import symbolic_trace
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.inductor_utils import HAS_CPU
# Platform flags derived from sys.platform. They select which
# platform-specific debug-symbol check the debug-symbol test performs.
_IS_MACOS = sys.platform.startswith("darwin")
_IS_WINDOWS = sys.platform == "win32"
def safe_command_output(cmd, timeout=30):
    """Run ``cmd`` and return its combined stdout/stderr as stripped text.

    A non-zero exit status or a timeout is reported through the returned
    string rather than an exception. Any other error (for example, a missing
    executable) is not caught here and still propagates.

    Args:
        cmd: Command to execute. A ``str`` is run through the shell; any
            other value (e.g. an argument list) is executed directly.
        timeout: Seconds to wait before giving up on the command.

    Returns:
        The stripped output on success; a ``"run failed, error code N: ..."``
        message (including the captured output) when the command exits with
        a non-zero status; or ``"run timeout"`` when ``timeout`` is exceeded.
    """
    try:
        output = subprocess.check_output(
            cmd,
            stderr=subprocess.STDOUT,  # fold stderr into the captured output
            text=True,
            timeout=timeout,
            shell=isinstance(cmd, str),
        )
    except subprocess.CalledProcessError as e:
        return f"run failed, error code {e.returncode}: {e.output.strip()}"
    except subprocess.TimeoutExpired:
        return "run timeout"
    return output.strip()
class MyModule(torch.nn.Module):
    """Tiny network: ``Linear(10, 10) -> ReLU -> Linear(10, 10) -> sigmoid``."""

    def __init__(self) -> None:
        super().__init__()
        # Layers are created in this order: ``a``, then ``b``, then ``relu``.
        self.a = torch.nn.Linear(in_features=10, out_features=10)
        self.b = torch.nn.Linear(in_features=10, out_features=10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.a(x))
        return torch.sigmoid(self.b(hidden))
class MyModule2(MyModule):
    """Variant of ``MyModule`` with a dict input and a dict output."""

    def forward(self, x):  # x is a dict whose "key" entry holds two items
        first, second = x["key"]
        base_out = super().forward(first)
        return {"result": base_out + second}
class MyModule3(MyModule):
    """Variant of ``MyModule`` that wraps its output in a one-element tuple."""

    def forward(self, x):
        base_out = super().forward(x)
        return (base_out,)
class TestStandaloneInductor(TestCase):
    """
    These tests check that you can call TorchInductor directly without
    going through TorchDynamo.
    """

    def test_inductor_via_fx(self):
        # symbolic_trace of a module that returns a tuple.
        mod = MyModule3().eval()
        inp = torch.randn(10)
        correct = mod(inp)
        mod_opt = inductor.compile(symbolic_trace(mod), [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)

    def test_inductor_via_fx_tensor_return(self):
        # symbolic_trace of a module that returns a bare tensor.
        mod = MyModule().eval()
        inp = torch.randn(10)
        correct = mod(inp)
        mod_opt = inductor.compile(symbolic_trace(mod), [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)

    def test_inductor_via_fx_dict_input(self):
        # symbolic_trace of a module that takes a dict of list and returns a dict.
        mod = MyModule2().eval()
        inp = {"key": [torch.randn(10), torch.randn(10)]}
        correct = mod(inp)
        mod_opt = inductor.compile(symbolic_trace(mod), [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)

    def test_inductor_via_make_fx(self):
        # Graph produced by make_fx instead of symbolic_trace.
        mod = MyModule().eval()
        inp = torch.randn(10)
        correct = mod(inp)
        mod_opt = inductor.compile(make_fx(mod)(inp), [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)

    def test_inductor_via_bare_module(self):
        mod = MyModule3().eval()
        inp = torch.randn(10)
        correct = mod(inp)
        # no FX graph at all (mod must return list/tuple in this case)
        mod_opt = inductor.compile(mod, [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)

    def test_inductor_via_export1(self):
        # Graph produced by dynamo.export with aten_graph=True.
        mod = MyModule3().eval()
        inp = torch.randn(10)
        correct = mod(inp)
        gm, _ = dynamo.export(mod, inp, aten_graph=True)
        mod_opt = inductor.compile(gm, [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)

    def test_inductor_via_export2(self):
        # Graph produced by dynamo.export for a dict-in / dict-out module.
        mod = MyModule2().eval()
        inp = {"key": [torch.randn(10), torch.randn(10)]}
        correct = mod(inp)
        gm, _ = dynamo.export(mod, inp)
        mod_opt = inductor.compile(gm, [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)

    def test_inductor_via_op_with_multiple_outputs(self):
        # Arguments passed positionally to aten.native_layer_norm.default.
        x1 = torch.randn((2, 512, 128))
        x2 = [128]
        x3 = torch.randn(128)
        x4 = torch.randn((128,))
        x5 = 1e-6
        mod, inp = gen_gm_and_inputs(
            torch.ops.aten.native_layer_norm.default, (x1, x2, x3, x4, x5), {}
        )
        mod_opt = inductor.compile(mod, inp)
        self.assertEqual(mod(*inp), mod_opt(*inp))

    @mock.patch.dict(os.environ, {"TORCHINDUCTOR_DEBUG_SYMBOL": "1"})
    def test_inductor_generate_debug_symbol(self):
        # Build a trivial C++ program with TORCHINDUCTOR_DEBUG_SYMBOL=1 set in
        # the environment, then look for the platform's debug-symbol artifact.
        cpp_code = "\nint main(){\nreturn 0;\n}\n"
        _, source_path = write(
            cpp_code,
            "cpp",
        )
        build_option = CppOptions()
        cpp_builder = CppBuilder(
            name="test_symbol",
            sources=source_path,
            output_dir=os.path.dirname(source_path),
            BuildOption=build_option,
        )
        cpp_builder.build()
        binary_path = cpp_builder.get_target_file_path()

        # When debug-symbol generation is turned on:
        #   * On Windows, a [module_name].pdb file should be created next to
        #     the binary. It helps when debugging with WinDbg.
        #   * On Linux, the binary file should contain debug sections.

        def check_linux_debug_section(module_path: str):
            # Pass argv as a list so a path containing whitespace or quotes
            # reaches readelf as a single argument.
            output = safe_command_output(["readelf", "-S", module_path])
            # assertIn shows the readelf output (or the failure message) when
            # the section is missing, which a bare True/False check would hide.
            self.assertIn(".debug_info", output)

        def check_windows_pdb_exist(module_path: str):
            file_name_no_ext = os.path.splitext(module_path)[0]
            file_name_pdb = f"{file_name_no_ext}.pdb"
            self.assertTrue(
                os.path.exists(file_name_pdb),
                f"expected PDB file was not generated: {file_name_pdb}",
            )

        if _IS_WINDOWS:
            check_windows_pdb_exist(binary_path)
        elif _IS_MACOS:
            pass  # No debug-symbol check is performed on macOS.
        else:
            check_linux_debug_section(binary_path)
# Run the suite only when executed as a script and HAS_CPU is truthy.
if __name__ == "__main__" and HAS_CPU:
    run_tests()