Remove some memory overhead in parallel compile workers (#149168)

Summary: The parallel compile workers hold on to more memory than they need because they load the compiled modules into memory and register them in caches that only the toplevel process ever reads. Update the post-fork initializer to record when we're running in a subprocess and skip that unnecessary overhead there.
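The mechanism, condensed from the diff below into a minimal sketch: the post-fork initializer flips a module-level flag, and call sites consult it before doing toplevel-only bookkeeping.

    _IN_TOPLEVEL_PROCESS = True

    def in_toplevel_process() -> bool:
        return _IN_TOPLEVEL_PROCESS

    def _async_compile_initializer(orig_ppid: int) -> None:
        # Runs in every compile worker right after fork/spawn (the real
        # initializer also installs the watchdog and signal handling),
        # so the flag stays True only in the toplevel process.
        global _IN_TOPLEVEL_PROCESS
        _IN_TOPLEVEL_PROCESS = False

    # Call sites then gate caching on the flag, e.g.:
    #     if in_toplevel_process():
    #         cls.modules.append(mod)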

Test Plan: Ran a test script to compile 15k Triton kernels and used tracemalloc in the subprocs to investigate the overhead (a sketch of the measurement follows the list). On my devgpu:
* After importing torch in a subproc: 371M
* Without this PR, after compiling 15k kernels: 825M
* With this PR, after compiling 15k kernels: 531M
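A minimal sketch of taking such a measurement with tracemalloc inside a worker; this is not the actual test script (the compile loop is elided), and tracemalloc only accounts for Python-level allocations:

    import tracemalloc

    tracemalloc.start()
    import torch  # noqa: F401  (baseline: allocations attributed to the import)

    after_import, _ = tracemalloc.get_traced_memory()

    # ... compile ~15k Triton kernels in this subprocess ...

    after_compile, peak = tracemalloc.get_traced_memory()
    print(
        f"import: {after_import / 2**20:.0f}M, "
        f"after compile: {after_compile / 2**20:.0f}M (peak: {peak / 2**20:.0f}M)"
    )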

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149168
Approved by: https://github.com/jansel
Author: Sam Larsen
Date: 2025-03-14 15:41:51 -07:00
Committed by: PyTorch MergeBot
Parent: e7e477c1f9
Commit: c83c711da8

6 changed files with 31 additions and 25 deletions

torch/_inductor/async_compile.py

@@ -35,7 +35,7 @@ from torch._inductor.codecache import (
     torch_key,
 )
 from torch._inductor.compile_worker.subproc_pool import AnyPool, SubprocPool
-from torch._inductor.compile_worker.watchdog import _async_compile_initializer
+from torch._inductor.compile_worker.utils import _async_compile_initializer
 from torch._inductor.runtime.compile_tasks import (
     _set_triton_ptxas_path,
     _worker_compile_triton,

torch/_inductor/codecache.py

@@ -53,6 +53,7 @@ from torch._inductor.codegen.rocm.compile_command import (
     rocm_compile_command,
     rocm_compiler,
 )
+from torch._inductor.compile_worker.utils import in_toplevel_process
 from torch._inductor.cpp_builder import (
     _LINKER_SCRIPT,
     _set_gpu_runtime_env,
@@ -68,10 +69,7 @@ from torch._inductor.cpp_builder import (
 from torch._inductor.cpu_vec_isa import pick_vec_isa
 from torch._inductor.custom_graph_pass import CustomGraphPass, CustomGraphPassType
 from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param
-from torch._inductor.runtime.compile_tasks import (
-    _reload_python_module,
-    _reload_python_module_in_subproc,
-)
+from torch._inductor.runtime.compile_tasks import _reload_python_module
 from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir
 from torch._inductor.utils import (
     ALIGN_BYTES,
@@ -2745,21 +2743,19 @@ class PyCodeCache:
         if linemap is None:
             linemap = []
 
-        mod = _reload_python_module(key, path)
+        in_toplevel = in_toplevel_process()
+        mod = _reload_python_module(key, path, set_sys_modules=in_toplevel)
 
         # unzip into separate lines/nodes lists
-        cls.linemaps[path] = list(zip(*linemap))
+        if in_toplevel:
+            cls.linemaps[path] = list(zip(*linemap))
 
         if attrs is not None:
             for k, v in attrs.items():
                 setattr(mod, k, v)
 
-        if not (linemap or attrs):
-            mod._reload_in_subproc = functools.partial(  # type: ignore[attr-defined]
-                _reload_python_module_in_subproc, key, path
-            )
-
-        cls.modules.append(mod)
+        if in_toplevel:
+            cls.modules.append(mod)
         return mod
 
     @classmethod

torch/_inductor/compile_worker/__main__.py

@@ -13,7 +13,7 @@ from torch._inductor.compile_worker.subproc_pool import (
     SubprocMain,
     SubprocPickler,
 )
-from torch._inductor.compile_worker.watchdog import _async_compile_initializer
+from torch._inductor.compile_worker.utils import _async_compile_initializer
 from torch._inductor.runtime.compile_tasks import _set_triton_ptxas_path

torch/_inductor/compile_worker/subproc_pool.py

@@ -21,7 +21,7 @@ from typing_extensions import Never, ParamSpec
 # functionality to destroy singletons before forking and re-enable them after.
 import torch._thread_safe_fork  # noqa: F401
 from torch._inductor import config
-from torch._inductor.compile_worker.watchdog import _async_compile_initializer
+from torch._inductor.compile_worker.utils import _async_compile_initializer
 
 log = logging.getLogger(__name__)

torch/_inductor/compile_worker/utils.py (renamed from watchdog.py)

@@ -5,6 +5,14 @@ from time import sleep
 from typing import Optional
 
+_IN_TOPLEVEL_PROCESS = True
+
+
+def in_toplevel_process() -> bool:
+    global _IN_TOPLEVEL_PROCESS
+    return _IN_TOPLEVEL_PROCESS
+
+
 # If this process dies abnormally (e.g. segfault)
 # it will not shut down the workers. Instead,
 # the workers will have their parent reassigned to the
@@ -28,6 +36,10 @@ def _async_compile_initializer(orig_ppid: int) -> None:
     # Ignore Ctrl-C (i.e. SIGINT) sent to pool workers to avoid meaningless log spam.
     signal.signal(signal.SIGINT, signal.SIG_IGN)
 
+    # Set a bit to distinguish async_compile subprocesses from the toplevel process.
+    global _IN_TOPLEVEL_PROCESS
+    _IN_TOPLEVEL_PROCESS = False
+
 
 _watchdog_thread: Optional[Thread] = None
 _original_parent: Optional[int] = None

torch/_inductor/runtime/compile_tasks.py

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import functools
+import linecache
 import os
 import sys
 import time
@@ -14,15 +15,9 @@ if TYPE_CHECKING:
     from torch._inductor.runtime.triton_heuristics import CachingAutotuner
 
 
-def _reload_python_module_in_subproc(key: str, path: str) -> ModuleType:
-    codecache = sys.modules.get("torch._inductor.codecache")
-    if codecache:
-        return codecache.PyCodeCache.load_by_key_path(key, path)
-    else:
-        return _reload_python_module(key, path)
-
-
-def _reload_python_module(key: str, path: str) -> ModuleType:
+def _reload_python_module(
+    key: str, path: str, set_sys_modules: bool = True
+) -> ModuleType:
     with open(path) as f:
         try:
             code = compile(f.read(), path, "exec", dont_inherit=True)
@@ -34,7 +29,8 @@ def _reload_python_module(key: str, path: str) -> ModuleType:
     mod.__file__ = path
     mod.key = key  # type: ignore[attr-defined]
     exec(code, mod.__dict__, mod.__dict__)
-    sys.modules[mod.__name__] = mod
+    if set_sys_modules:
+        sys.modules[mod.__name__] = mod
     return mod
@@ -61,4 +57,6 @@ def _worker_compile_triton(
     kernel.precompile(warm_cache_only=True)
     elapsed_ns = time.time_ns() - start_ns
     kernel.prepare_for_pickle()
+    # We can release this memory in the compile subprocesses:
+    linecache.clearcache()
     return kernel, elapsed_ns // 1000
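For context on this last hunk: linecache retains the full source of every file it has served (for tracebacks, inspect, and the like), so a long-lived worker that loads thousands of generated kernel modules accumulates all of their sources. A standalone illustration of that cache and its release (linecache.cache is an implementation detail, but a long-standing one):

    import linecache

    linecache.getline(__file__, 1)   # reading one line caches the whole file
    assert __file__ in linecache.cache

    linecache.clearcache()           # drops every cached file at once
    assert not linecache.cache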