Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Remove some memory overhead in parallel compile workers (#149168)
Summary: The parallel compile workers are holding on to more memory than they need because they load the compiled modules into memory. Update the post-fork initializer to record when running in a subprocess and skip some of the unnecessary overhead.

Test Plan: Ran a test script to compile 15k Triton kernels and used tracemalloc in the subprocesses to investigate the overhead. On my devgpu:
* After importing torch in a subprocess: 371M
* Without this PR, after compiling 15k kernels: 825M
* With this PR, after compiling 15k kernels: 531M

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149168
Approved by: https://github.com/jansel
Committed by: PyTorch MergeBot
Parent: e7e477c1f9
Commit: c83c711da8
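The Test Plan numbers above come from running tracemalloc inside the compile workers. Below is a minimal sketch of that style of measurement, illustrative only: the pool size, stand-in task, and allocation sizes are invented for the example and it does not call Inductor's actual compile entry points.

# Illustrative only: measure per-worker Python heap growth with tracemalloc,
# in the spirit of the Test Plan above.
import tracemalloc
from concurrent.futures import ProcessPoolExecutor


def _init_worker() -> None:
    # Post-fork/spawn initializer: start tracing allocations in each worker.
    tracemalloc.start()


def _fake_compile(i: int) -> int:
    # Stand-in for compiling one kernel; allocates a little memory per task.
    blob = ("x" * 1024) * 64
    return len(blob)


def _peak_mb(_: int) -> float:
    current, peak = tracemalloc.get_traced_memory()
    return peak / 1e6


if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=4, initializer=_init_worker) as pool:
        list(pool.map(_fake_compile, range(1_000)))
        # Approximate: map() does not guarantee one probe task per worker.
        print(sorted(pool.map(_peak_mb, range(8))))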
@@ -35,7 +35,7 @@ from torch._inductor.codecache import (
     torch_key,
 )
 from torch._inductor.compile_worker.subproc_pool import AnyPool, SubprocPool
-from torch._inductor.compile_worker.watchdog import _async_compile_initializer
+from torch._inductor.compile_worker.utils import _async_compile_initializer
 from torch._inductor.runtime.compile_tasks import (
     _set_triton_ptxas_path,
     _worker_compile_triton,
@@ -53,6 +53,7 @@ from torch._inductor.codegen.rocm.compile_command import (
     rocm_compile_command,
     rocm_compiler,
 )
+from torch._inductor.compile_worker.utils import in_toplevel_process
 from torch._inductor.cpp_builder import (
     _LINKER_SCRIPT,
     _set_gpu_runtime_env,
@@ -68,10 +69,7 @@ from torch._inductor.cpp_builder import (
 from torch._inductor.cpu_vec_isa import pick_vec_isa
 from torch._inductor.custom_graph_pass import CustomGraphPass, CustomGraphPassType
 from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param
-from torch._inductor.runtime.compile_tasks import (
-    _reload_python_module,
-    _reload_python_module_in_subproc,
-)
+from torch._inductor.runtime.compile_tasks import _reload_python_module
 from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir
 from torch._inductor.utils import (
     ALIGN_BYTES,
@@ -2745,21 +2743,19 @@ class PyCodeCache:
         if linemap is None:
             linemap = []
 
-        mod = _reload_python_module(key, path)
+        in_toplevel = in_toplevel_process()
+        mod = _reload_python_module(key, path, set_sys_modules=in_toplevel)
 
         # unzip into separate lines/nodes lists
-        cls.linemaps[path] = list(zip(*linemap))
+        if in_toplevel:
+            cls.linemaps[path] = list(zip(*linemap))
 
         if attrs is not None:
             for k, v in attrs.items():
                 setattr(mod, k, v)
 
-        if not (linemap or attrs):
-            mod._reload_in_subproc = functools.partial(  # type: ignore[attr-defined]
-                _reload_python_module_in_subproc, key, path
-            )
-
-        cls.modules.append(mod)
+        if in_toplevel:
+            cls.modules.append(mod)
         return mod
 
     @classmethod
@@ -13,7 +13,7 @@ from torch._inductor.compile_worker.subproc_pool import (
     SubprocMain,
     SubprocPickler,
 )
-from torch._inductor.compile_worker.watchdog import _async_compile_initializer
+from torch._inductor.compile_worker.utils import _async_compile_initializer
 from torch._inductor.runtime.compile_tasks import _set_triton_ptxas_path
 
 
@@ -21,7 +21,7 @@ from typing_extensions import Never, ParamSpec
 # functionality to destroy singletons before forking and re-enable them after.
 import torch._thread_safe_fork  # noqa: F401
 from torch._inductor import config
-from torch._inductor.compile_worker.watchdog import _async_compile_initializer
+from torch._inductor.compile_worker.utils import _async_compile_initializer
 
 
 log = logging.getLogger(__name__)
@@ -5,6 +5,14 @@ from time import sleep
 from typing import Optional
 
 
+_IN_TOPLEVEL_PROCESS = True
+
+
+def in_toplevel_process() -> bool:
+    global _IN_TOPLEVEL_PROCESS
+    return _IN_TOPLEVEL_PROCESS
+
+
 # If this process dies abnormally (e.g. segfault)
 # it will not shut down the workers. Instead,
 # the workers will have their parent reassigned to the
@@ -28,6 +36,10 @@ def _async_compile_initializer(orig_ppid: int) -> None:
     # Ignore Ctrl-C (i.e. SIGINT) sent to pool workers to avoid meaningless log spam.
     signal.signal(signal.SIGINT, signal.SIG_IGN)
 
+    # Set a bit to distinguish async_compile subprocesses from the toplevel process.
+    global _IN_TOPLEVEL_PROCESS
+    _IN_TOPLEVEL_PROCESS = False
+
 
 _watchdog_thread: Optional[Thread] = None
 _original_parent: Optional[int] = None
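The two hunks above add a module-level _IN_TOPLEVEL_PROCESS flag to torch._inductor.compile_worker.utils and flip it in the post-fork initializer. A standalone sketch of the same flag-after-fork pattern follows; the pool and task here are generic stand-ins, not Inductor's SubprocPool or its worker entry points.

# Standalone sketch of the flag-after-fork pattern; names mirror the diff but
# the pool/task are generic stand-ins, not Inductor's actual classes.
from concurrent.futures import ProcessPoolExecutor

_IN_TOPLEVEL_PROCESS = True


def in_toplevel_process() -> bool:
    return _IN_TOPLEVEL_PROCESS


def _initializer() -> None:
    # Runs once in every worker after it is forked/spawned, so only the
    # workers see the flag flipped; the parent keeps the default True.
    global _IN_TOPLEVEL_PROCESS
    _IN_TOPLEVEL_PROCESS = False


def describe(_: int) -> str:
    # Downstream code can branch on the flag to skip bookkeeping that only
    # the toplevel process needs (e.g. module caches and line maps).
    return "toplevel" if in_toplevel_process() else "worker"


if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=2, initializer=_initializer) as pool:
        print(describe(0))                         # toplevel
        print(list(pool.map(describe, range(2))))  # ['worker', 'worker']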
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import functools
+import linecache
 import os
 import sys
 import time
@@ -14,15 +15,9 @@ if TYPE_CHECKING:
     from torch._inductor.runtime.triton_heuristics import CachingAutotuner
 
 
-def _reload_python_module_in_subproc(key: str, path: str) -> ModuleType:
-    codecache = sys.modules.get("torch._inductor.codecache")
-    if codecache:
-        return codecache.PyCodeCache.load_by_key_path(key, path)
-    else:
-        return _reload_python_module(key, path)
-
-
-def _reload_python_module(key: str, path: str) -> ModuleType:
+def _reload_python_module(
+    key: str, path: str, set_sys_modules: bool = True
+) -> ModuleType:
     with open(path) as f:
         try:
             code = compile(f.read(), path, "exec", dont_inherit=True)
@@ -34,7 +29,8 @@ def _reload_python_module(key: str, path: str) -> ModuleType:
     mod.__file__ = path
     mod.key = key  # type: ignore[attr-defined]
     exec(code, mod.__dict__, mod.__dict__)
-    sys.modules[mod.__name__] = mod
+    if set_sys_modules:
+        sys.modules[mod.__name__] = mod
     return mod
 
 
@@ -61,4 +57,6 @@ def _worker_compile_triton(
     kernel.precompile(warm_cache_only=True)
     elapsed_ns = time.time_ns() - start_ns
     kernel.prepare_for_pickle()
+    # We can release this memory in the compile subprocesses:
+    linecache.clearcache()
     return kernel, elapsed_ns // 1000
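The guarded sys.modules registration above is what lets a worker drop a generated module once the task is done. A minimal standalone illustration of that idea follows; the helper name and the temp-file demo are invented for the example and are not part of Inductor.

# Minimal illustration of loading generated source as a module while optionally
# skipping sys.modules registration, as the guarded branch above does for
# compile workers. load_generated() and the demo file are invented here.
import sys
import tempfile
from types import ModuleType


def load_generated(path: str, name: str, set_sys_modules: bool = True) -> ModuleType:
    with open(path) as f:
        code = compile(f.read(), path, "exec", dont_inherit=True)
    mod = ModuleType(name)
    mod.__file__ = path
    exec(code, mod.__dict__, mod.__dict__)
    if set_sys_modules:
        # Only the toplevel process needs the module reachable by name; a
        # worker can skip this so the module is collectable after use.
        sys.modules[name] = mod
    return mod


if __name__ == "__main__":
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write("ANSWER = 42\n")
        path = f.name
    mod = load_generated(path, "generated_demo", set_sys_modules=False)
    print(mod.ANSWER)                       # 42
    print("generated_demo" in sys.modules)  # False: not registered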