Compare commits

...

1 Commits

Author SHA1 Message Date
ff5cf2b3a2 Inductor: Prefer libgomp from Pytorch instead of system one 2025-11-10 11:26:34 +00:00

View File

@ -19,6 +19,7 @@ import sysconfig
import tempfile
import textwrap
import warnings
import glob
from collections.abc import Sequence
from ctypes import cdll, wintypes
from ctypes.util import find_library
@ -1320,8 +1321,26 @@ def _get_openmp_args(
elif _is_intel_compiler(cpp_compiler):
cflags.append("fiopenmp")
else:
cflags.append("fopenmp")
libs.append("gomp")
# GCC on Linux
# Explicitly control OpenMP linkage and prefer torch libgomp if available
cflags += ["fopenmp", "Wno-unknown-pragmas"]
torch_root = Path(torch.__file__).resolve().parent
for d in [torch_root / "lib", (torch_root / ".." / "torch.libs").resolve()]:
torch_libgomp = glob.glob(str(d / "libgomp-*.so*"))
if torch_libgomp:
path = Path(torch_libgomp[0]).resolve()
ldflags += [
f"-L{path.parent}",
f"-Wl,-rpath,{path.parent}",
"-Wl,--disable-new-dtags",
"-Wl,--no-as-needed",
]
libs.append(f":{path.name}")
lib_dir_paths.append(str(path.parent))
print(f"[DEBUG] Using torch libgomp from: {path}")
break
else:
libs.append("gomp")
return cflags, ldflags, include_dir_paths, lib_dir_paths, libs, passthrough_args
@ -1987,6 +2006,29 @@ class CppBuilder:
passthrough_args=self._passthrough_parameters_args,
output=self._output,
)
# Begin torch libgomp preference (Linux/aarch64)
try:
if platform.system() == "Linux" and platform.machine().lower() in ("aarch64", "arm64"):
torch_root = Path(torch.__file__).resolve().parent
for d in [torch_root / "lib", (torch_root / ".." / "torch.libs").resolve()]:
torch_libgomp = glob.glob(str(d / "libgomp-*.so*"))
if torch_libgomp:
path = Path(torch_libgomp[0]).resolve()
os.environ["LD_PRELOAD"] = str(path)
print(f"[DEBUG] LD_PRELOAD set to {path}")
break
except Exception as e:
print(f"[DEBUG] torch libgomp handling failed: {e}")
# Fix malformed flags (e.g., '--L' instead of '-L')
command_line = (
command_line.replace("--L", "-L")
.replace("--Wl,", "-Wl,")
.replace("--Wl", "-Wl,")
.replace("-Wl-", "-Wl,")
)
print(f"[DEBUG] Final command line: {command_line}")
return command_line
def get_target_file_path(self) -> str: