mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-07 18:04:58 +08:00
In this PR, we replace OMP SIMD with `aten::vec` to optimize TorchInductor vectorization performance. Take `res=torch.exp(torch.add(x, y))` as the example. The generated code is as follows if `config.cpp.simdlen` is 8.
```C++
extern "C" void kernel(const float* __restrict__ in_ptr0,
const float* __restrict__ in_ptr1,
float* __restrict__ out_ptr0,
const long ks0,
const long ks1)
{
#pragma omp parallel num_threads(48)
{
#pragma omp for
for(long i0=0; i0<((ks0*ks1) / 8); ++i0)
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + 8*i0);
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr1 + 8*i0);
auto tmp2 = tmp0 + tmp1;
auto tmp3 = tmp2.exp();
tmp3.store(out_ptr0 + 8*i0);
}
#pragma omp for simd simdlen(4)
for(long i0=8*(((ks0*ks1) / 8)); i0<ks0*ks1; ++i0)
{
auto tmp0 = in_ptr0[i0];
auto tmp1 = in_ptr1[i0];
auto tmp2 = tmp0 + tmp1;
auto tmp3 = std::exp(tmp2);
out_ptr0[i0] = tmp3;
}
}
}
```
The major pipeline is as follows.
- Check whether the loop body can be vectorized by `aten::vec`. The checker consists of two parts. [One](bf66991fc4/torch/_inductor/codegen/cpp.py (L702)) checks whether all the `ops` are supported. The [other one](355326faa3/torch/_inductor/codegen/cpp.py (L672)) checks whether the data access can be vectorized.
- [`CppSimdVecKernelChecker`](355326faa3/torch/_inductor/codegen/cpp.py (L655))
- Create the `aten::vec` kernel and the original OMP SIMD kernel. The original OMP SIMD kernel handles the tail loop when the main loop is vectorized.
- [`CppSimdVecKernel`](355326faa3/torch/_inductor/codegen/cpp.py (L601))
- [`CppSimdVecOverrides`](355326faa3/torch/_inductor/codegen/cpp.py (L159)): the ops that we support on top of `aten::vec`
- Create kernel
- [`aten::vec` kernel](355326faa3/torch/_inductor/codegen/cpp.py (L924))
- [`Original CPP kernel - OMP SIMD`](355326faa3/torch/_inductor/codegen/cpp.py (L929))
- Generate code
- [`CppKernelProxy`](355326faa3/torch/_inductor/codegen/cpp.py (L753)) is used to combine the `aten::vec` kernel and original cpp kernel
- [Vectorize the most inner loop](355326faa3/torch/_inductor/codegen/cpp.py (L753))
- [Generate code](355326faa3/torch/_inductor/codegen/cpp.py (L821))
Next steps:
- [x] Support reduction
- [x] Vectorize the tail loop with `aten::vec`
- [ ] Support BF16
- [ ] Optimize the loop condition and loop index calculation by replacing `div` with `add`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87068
Approved by: https://github.com/jgong5, https://github.com/jansel
468 lines
15 KiB
Python
468 lines
15 KiB
Python
import base64
|
|
import enum
|
|
import functools
|
|
import getpass
|
|
import hashlib
|
|
import logging
|
|
import multiprocessing
|
|
import os
|
|
import re
|
|
import shutil
|
|
import signal
|
|
import subprocess
|
|
import sys
|
|
import sysconfig
|
|
import tempfile
|
|
import types
|
|
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
|
|
from ctypes import cdll
|
|
from threading import Thread
|
|
from time import sleep, time
|
|
from typing import Any, Dict
|
|
|
|
import torch
|
|
from torch.utils import cpp_extension
|
|
|
|
from . import config, cuda_properties, exc
|
|
|
|
# Timeout handed to filelock.FileLock(timeout=...) wherever this module takes a
# file lock (presumably seconds, per the filelock API -- confirm).
LOCK_TIMEOUT = 600

# timing metrics for time spent in the compilation
# Total accumulated by _compile_end() across all start/end intervals.
_cumulative_compile_time = 0
# Start timestamp of the interval currently being measured, or None when no
# interval is open.
_t0 = None
|
|
|
|
|
|
def _compile_start():
    """Open a compile-timing interval unless one is already open."""
    global _t0
    if _t0 is not None:
        # An interval is already running; keep its original start time.
        return
    _t0 = time()
|
|
|
|
|
|
def _compile_end():
    """Close the open compile-timing interval and add it to the running total.

    Does nothing when no interval is open.
    """
    global _cumulative_compile_time, _t0
    if _t0 is None:
        return
    elapsed = time() - _t0
    _cumulative_compile_time += elapsed
    _t0 = None
    # print("CUMULATIVE COMPILE TIME", _cumulative_compile_time)
|
|
|
|
|
|
# Module-level logger for this file.
log = logging.getLogger(__name__)
# The "filelock" logger only emits DEBUG records when config.debug is set;
# otherwise it is capped at INFO.
logging.getLogger("filelock").setLevel(logging.DEBUG if config.debug else logging.INFO)
|
|
|
|
|
|
@functools.lru_cache(None)
def cache_dir():
    """Return the root directory used for inductor's on-disk caches.

    The ``TORCHINDUCTOR_CACHE_DIR`` environment variable wins when it is set;
    otherwise a per-user directory under ``/tmp`` is used. The result is
    memoized, so later changes to the environment are not observed.

    Returns:
        The cache directory path as a string.
    """
    configured = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
    if configured is not None:
        return configured
    # Only look up the user name when it is actually needed: getpass.getuser()
    # can raise when the current UID has no passwd entry and no user-name
    # environment variable is set, which must not break an explicit override.
    return f"/tmp/torchinductor_{getpass.getuser()}"
|
|
|
|
|
|
def get_lock_dir():
    """Return the ``locks`` directory under the cache root, creating it if absent."""
    locks_path = os.path.join(cache_dir(), "locks")
    if os.path.exists(locks_path):
        return locks_path
    os.makedirs(locks_path, exist_ok=True)
    return locks_path
|
|
|
|
|
|
def code_hash(code):
    """Return a 52-character key for ``code``.

    The key is the letter ``c`` followed by the first 51 characters of the
    lower-cased base32 encoding of the SHA-256 digest of the UTF-8 text.
    """
    digest = hashlib.sha256(code.encode("utf-8")).digest()
    encoded = base64.b32encode(digest)[:51]
    return "c" + encoded.decode("utf-8").lower()
|
|
|
|
|
|
def write(source_code, ext, extra=""):
    """Store ``source_code`` in the cache directory under a content-derived name.

    Args:
        source_code: Text to write to disk.
        ext: File extension (without the dot) for the cached file.
        extra: Additional text mixed into the hash only (not written), so the
            same source can map to different cache entries.

    Returns:
        A ``(basename, path)`` tuple: the hash key and the full file path.
    """
    basename = code_hash(source_code + extra)
    # Shard by two characters of the hash so no single directory grows huge.
    subdir = os.path.join(cache_dir(), basename[1:3])
    os.makedirs(subdir, exist_ok=True)
    path = os.path.join(subdir, f"{basename}.{ext}")
    if not os.path.exists(path):
        # use a temp file for thread safety
        fd, tmp_path = tempfile.mkstemp(dir=subdir)
        try:
            with os.fdopen(fd, "w") as f:
                f.write(source_code)
            # os.replace overwrites atomically on every platform; os.rename
            # raises on Windows if a concurrent writer created `path` after
            # the existence check above.
            os.replace(tmp_path, path)
        except BaseException:
            # Do not leave a stray temp file behind in the cache directory.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise
    return basename, path
|
|
|
|
|
|
def cpp_compiler():
    """Resolve the C++ compiler to use from ``config.cpp.cxx``.

    The setting may be a single entry or a list/tuple of candidates; either
    way it is normalized to a tuple and handed to ``cpp_compiler_search``.
    """
    configured = config.cpp.cxx
    if isinstance(configured, (list, tuple)):
        candidates = tuple(configured)
    else:
        candidates = (configured,)
    return cpp_compiler_search(candidates)
|
|
|
|
|
|
@functools.lru_cache(1)
def cpp_compiler_search(search):
    """Return the first candidate in ``search`` that answers ``--version``.

    A ``None`` candidate means "install g++ through conda"; the install runs
    under a file lock. Candidates that fail to run, are missing, or need an
    unavailable import are skipped.

    Raises:
        exc.InvalidCxxCompiler: if no candidate works.
    """
    skippable = (subprocess.SubprocessError, FileNotFoundError, ImportError)
    for candidate in search:
        compiler = candidate
        try:
            if compiler is None:
                from filelock import FileLock

                lock_path = os.path.join(get_lock_dir(), "g++.lock")
                with FileLock(lock_path, timeout=LOCK_TIMEOUT):
                    compiler = install_gcc_via_conda()
            # Probe the binary; a non-zero exit or missing file rejects it.
            subprocess.check_output([compiler, "--version"])
        except skippable:
            continue
        return compiler
    raise exc.InvalidCxxCompiler()
|
|
|
|
|
|
def install_gcc_via_conda():
    """On older systems, this is a quick way to get a modern compiler.

    Creates a conda environment containing ``gxx`` under ``<cache_dir>/gcc``
    unless ``g++`` is already present there.

    Returns:
        The expected path of the ``g++`` binary inside that environment. The
        path may not exist if no conda executable was found; callers are
        expected to probe it.

    Raises:
        subprocess.CalledProcessError: if the conda command exits non-zero.
    """
    prefix = os.path.join(cache_dir(), "gcc")
    cxx_path = os.path.join(prefix, "bin", "g++")
    if not os.path.exists(cxx_path):
        log.info("Downloading GCC via conda")
        # Prefer the executable conda advertises for itself, then fall back to
        # a PATH lookup. (Previously the "conda" default made the fallback
        # branch unreachable.)
        conda = os.environ.get("CONDA_EXE") or shutil.which("conda")
        if conda is not None:
            subprocess.check_call(
                [
                    conda,
                    "create",
                    f"--prefix={prefix}",
                    "--channel=conda-forge",
                    "--quiet",
                    "-y",
                    "python=3.8",
                    "gxx",
                ],
                # Discard output. A PIPE that nobody reads can fill up and
                # block the child process.
                stdout=subprocess.DEVNULL,
            )
    return cxx_path
|
|
|
|
|
|
def is_gcc():
    """Return True when the resolved C++ compiler name contains ``gcc`` or ``g++``."""
    # Wrap in bool() so callers get a real boolean rather than a Match object.
    return bool(re.search(r"(gcc|g\+\+)", cpp_compiler()))
|
|
|
|
|
|
class _SupportedVecIsa(enum.Enum):
    """Vector instruction sets the C++ backend can target.

    ``INVALID`` is falsy so a value can be used directly in a condition.
    """

    AVX512 = 1
    AVX2 = 2
    INVALID = -1

    def __bool__(self):
        # Every member except INVALID counts as "a usable ISA".
        return self is not _SupportedVecIsa.INVALID

    @staticmethod
    def isa_str(supported_isa: enum.Enum):
        """Return the lower-case ISA name, or "" for anything unrecognized."""
        if supported_isa == _SupportedVecIsa.AVX512:
            return "avx512"
        if supported_isa == _SupportedVecIsa.AVX2:
            return "avx2"
        return ""

    @staticmethod
    def vec_macro(supported_isa: enum.Enum):
        """Return the CPU_CAPABILITY_* macro name, or "" for anything unrecognized."""
        if supported_isa == _SupportedVecIsa.AVX512:
            return "CPU_CAPABILITY_AVX512"
        if supported_isa == _SupportedVecIsa.AVX2:
            return "CPU_CAPABILITY_AVX2"
        return ""
|
|
|
|
|
|
# The result is memoized so /proc/cpuinfo is read at most once. The raw file
# is mostly irrelevant to the ISA check, so only the list of detected ISAs is
# kept rather than the file content.
@functools.lru_cache(1)
def get_cpu_proc_info():
    """Return the vector ISAs whose names appear in ``/proc/cpuinfo``.

    On non-Linux platforms the result is an empty list. AVX512 is listed
    before AVX2 when both are present.
    """
    if sys.platform != "linux":
        return []

    with open("/proc/cpuinfo") as cpuinfo_file:
        cpuinfo_text = cpuinfo_file.read()

    probe_order = (_SupportedVecIsa.AVX512, _SupportedVecIsa.AVX2)
    return [
        isa for isa in probe_order if _SupportedVecIsa.isa_str(isa) in cpuinfo_text
    ]
|
|
|
|
|
|
def supported_vector_isa():
    """Pick the vector ISA matching ``config.cpp.simdlen`` on this CPU.

    Returns ``_SupportedVecIsa.INVALID`` when simdlen is unset or <= 1, or
    when no detected ISA has exactly that many float lanes.
    """
    # TODO: Add ARM Vec here.
    simdlen = config.cpp.simdlen
    if simdlen is None or simdlen <= 1:
        return _SupportedVecIsa.INVALID

    # Dict(k: isa, v: number of float element)
    float_lanes = {
        _SupportedVecIsa.AVX512: 16,
        _SupportedVecIsa.AVX2: 8,
    }
    detected = get_cpu_proc_info()
    for isa, lanes in float_lanes.items():
        if isa in detected and simdlen == lanes:
            return isa

    return _SupportedVecIsa.INVALID
|
|
|
|
|
|
def cpp_compile_command(input, output, include_pytorch=False):
    """Build the single-line shell command that compiles ``input`` to ``output``.

    Args:
        input: Path of the C++ source file.
        output: Path of the shared object to produce.
        include_pytorch: When true, link against the PyTorch libraries even
            if no vector ISA is in use.

    Returns:
        The command as one string with all whitespace runs collapsed to
        single spaces.
    """
    valid_isa = supported_vector_isa()
    # The header search path is the same in both configurations.
    include_dirs = cpp_extension.include_paths() + [sysconfig.get_path("include")]
    if include_pytorch or valid_isa:
        library_dirs = cpp_extension.library_paths() + [
            sysconfig.get_config_var("LIBDIR")
        ]
        link_names = ["c10", "torch", "torch_cpu", "torch_python", "gomp"]
        macro_name = _SupportedVecIsa.vec_macro(valid_isa)
        macros = f"-D{macro_name}" if macro_name else macro_name
    else:
        # Note - this is effectively a header only inclusion. Usage of some header files may result in
        # symbol not found, if those header files require a library.
        # For those cases, include the lpath and libs command as we do for pytorch above.
        # This approach allows us to only pay for what we use.
        library_dirs = []
        link_names = ["gomp"]
        macros = ""

    ipaths = " ".join("-I" + entry for entry in include_dirs)
    lpaths = " ".join("-L" + entry for entry in library_dirs)
    libs = " ".join("-l" + entry for entry in link_names)

    template = f"""
        {cpp_compiler()} {input} -shared -fPIC -Wall -std=c++14 -Wno-unused-variable
        {ipaths} {lpaths} {libs} {macros}
        -march=native -O3 -ffast-math -fno-finite-math-only -fopenmp
        -o{output}
    """
    # Collapse the layout whitespace of the template into single spaces.
    return re.sub(r"[ \n]+", " ", template).strip()
|
|
|
|
|
|
class CppCodeCache:
    """In-process cache of compiled C++ kernels, keyed by source hash."""

    cache = {}
    clear = staticmethod(cache.clear)

    @classmethod
    def load(cls, source_code):
        """Compile ``source_code`` if needed and return the loaded library.

        The cache key also depends on the compile command, so a changed
        command yields a separate entry. Compilation runs under a per-key
        file lock.

        Raises:
            exc.CppCompileError: if the compiler exits with an error.
        """
        key, input_path = write(source_code, "cpp", extra=cpp_compile_command("i", "o"))
        if key in cls.cache:
            return cls.cache[key]

        from filelock import FileLock

        lock_path = os.path.join(get_lock_dir(), key + ".lock")
        with FileLock(lock_path, timeout=LOCK_TIMEOUT):
            # Swap the trailing "cpp" extension for "so".
            output_path = input_path[:-3] + "so"
            if not os.path.exists(output_path):
                command = cpp_compile_command(
                    input=input_path, output=output_path
                ).split(" ")
                try:
                    subprocess.check_output(command, stderr=subprocess.STDOUT)
                except subprocess.CalledProcessError as e:
                    raise exc.CppCompileError(command, e.output)

            library = cdll.LoadLibrary(output_path)
            cls.cache[key] = library
            library.key = key

        return cls.cache[key]
|
|
|
|
|
|
class PyCodeCache:
    """In-process cache of generated Python modules, keyed by source hash."""

    cache = {}
    clear = staticmethod(cache.clear)

    @classmethod
    def load(cls, source_code):
        """Write ``source_code`` to the cache, execute it, and return the module."""
        key, path = write(source_code, "py")
        if key in cls.cache:
            return cls.cache[key]

        with open(path) as source_file:
            compiled = compile(source_file.read(), path, "exec")
        module = types.ModuleType(f"{__name__}.{key}")
        module.__file__ = path
        module.key = key
        namespace = module.__dict__
        exec(compiled, namespace, namespace)
        # another thread might set this first
        cls.cache.setdefault(key, module)
        return cls.cache[key]
|
|
|
|
|
|
@functools.lru_cache(None)
def patch_triton_dir():
    """Default ``TRITON_CACHE_DIR`` to ``<cache_dir>/triton`` if it is unset.

    Memoized, so the environment is only touched on the first call.
    """
    fallback = os.path.join(cache_dir(), "triton")
    # Keeps an existing value; only fills in the fallback when the key is absent.
    os.environ.setdefault("TRITON_CACHE_DIR", fallback)
|
|
|
|
|
|
class TritonCodeCache:
    """Loads generated Triton modules and picks out their kernel object."""

    @staticmethod
    def get_name(mod):
        """Return the single attribute of ``mod`` whose name starts with "kernel".

        Raises:
            ValueError: if there is not exactly one such attribute.
        """
        kernel_names = [attr for attr in dir(mod) if attr.startswith("kernel")]
        # Unpacking enforces "exactly one match".
        (only_name,) = kernel_names
        return only_name

    @classmethod
    def load(cls, source_code):
        """Load ``source_code`` as a module and return its kernel attribute."""
        patch_triton_dir()
        module = PyCodeCache.load(source_code)
        kernel_name = cls.get_name(module)
        return getattr(module, kernel_name)
|
|
|
|
|
|
def _worker_compile(source_code, cc, device):
    """Load a Triton kernel and precompile it in cache-warming mode.

    The device is recorded via ``cuda_properties`` first, then the kernel is
    precompiled with ``warm_cache_only_with_cc=cc``. Nothing is returned.
    """
    cuda_properties.set_compiler_worker_current_device(device)
    triton_kernel = TritonCodeCache.load(source_code)
    triton_kernel.precompile(warm_cache_only_with_cc=cc)
|
|
|
|
|
|
def _load_kernel(source_code):
    """Load the Triton kernel for ``source_code``, precompile it, and return it."""
    triton_kernel = TritonCodeCache.load(source_code)
    triton_kernel.precompile()
    return triton_kernel
|
|
|
|
|
|
class TritonFuture:
    """Handle for a Triton kernel whose compilation was submitted to a worker."""

    def __init__(self, source_code, future):
        self.future = future
        self.source_code = source_code

    def result(self):
        """Wait for the worker, then load and return the kernel.

        The kernel is stored on first use and returned directly afterwards;
        the source and the underlying future are dropped once it is loaded.
        """
        try:
            return self.kernel
        except AttributeError:
            pass
        # If the worker failed this will throw an exception.
        self.future.result()
        loaded = _load_kernel(self.source_code)
        self.kernel = loaded
        del self.source_code, self.future
        return loaded
|
|
|
|
|
|
class AsyncCompile:
    """Front end for compiling generated kernels, optionally in parallel.

    When ``config.compile_threads`` is <= 1 everything runs inline in the
    caller. Otherwise C++ tasks go to a thread pool and Triton compilation
    goes to a pool of forked worker processes.
    """

    def __init__(self):
        # Filled in lazily by triton() with a small CUDA tensor.
        self._context_keepalive = None

    @staticmethod
    @functools.lru_cache(1)
    def pool():
        """Return the shared ThreadPoolExecutor (created once, then memoized)."""
        assert config.compile_threads > 1
        return ThreadPoolExecutor(config.compile_threads)

    @staticmethod
    @functools.lru_cache(1)
    def process_pool():
        """Return the shared fork-based ProcessPoolExecutor (created once, then memoized)."""
        # ensure properties have been calculated before processes
        # are forked
        cuda_properties._properties()
        assert config.compile_threads > 1
        # PID of the process creating the pool; workers compare their parent
        # PID against this value.
        orig_ppid = os.getpid()

        # if this process dies abnormally (e.g. segfault)
        # it will not shut down the workers. Instead
        # the workers will have their parent reassigned to the
        # init process. This launches a separate thread to
        # watch for the worker getting reassigned,
        # and cleans it up in this case.
        def init():
            def run():
                # Poll once per second; kill this worker as soon as its
                # parent is no longer the process that created the pool.
                while True:
                    sleep(1)
                    if orig_ppid != os.getppid():
                        os.kill(os.getpid(), signal.SIGKILL)

            # Stored in a module global inside the worker process.
            global _watchdog_thread
            _watchdog_thread = Thread(target=run, daemon=True)
            _watchdog_thread.start()

        # we rely on 'fork' because we cannot control whether users
        # have an `if __name__ == '__main__'` in their main process.
        fork_context = multiprocessing.get_context("fork")
        pool = ProcessPoolExecutor(
            config.compile_threads, mp_context=fork_context, initializer=init
        )
        # when this pool is created in a subprocess object, the normal exit handler
        # doesn't run, and we need to register our own handler.
        # exitpriority has to be high, because another one of the finalizers will
        # kill the worker thread that sends the shutdown message to the workers...
        multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize)
        return pool

    @classmethod
    def warm_pool(cls):
        """Create the process pool and force its workers to start immediately.

        Does nothing when ``config.compile_threads`` is <= 1. The time spent
        is counted as compile time.
        """
        if config.compile_threads <= 1:
            return
        _compile_start()
        pool = cls.process_pool()

        # We have to fork processes for compiler workers, but the more memory and other resources that are loaded, the
        # slower the os.fork time is, quite drastically. It also holds the GIL so we can't put it on another thread.

        # Examples:
        # A simple x + x + x script: 10ms in the middle of the program, 2ms at startup
        # tf_efficientnet_b0 benchmark: 50ms! in the middle of the program, 3ms at startup

        # So we want to start the workers early when it is still cheap, and also to allow the workers to get
        # ready before we have work for them.

        # ProcessPoolExecutor also does not launch the workers until it finds a point when all the workers are idle.
        # But if we waited until then fork time will be long and we will be waiting for the processes to initialize.

        # We force them to start here with some YOLOing of the internal methods.
        # These are private ProcessPoolExecutor methods; which one exists
        # presumably depends on the Python version -- verify when upgrading.
        if hasattr(pool, "_start_queue_management_thread"):
            pool._start_queue_management_thread()
        else:
            for i in range(config.compile_threads):
                pool._adjust_process_count()
            pool._start_executor_manager_thread()
        _compile_end()

    @classmethod
    def submit(cls, task):
        """Run ``task`` inline, or submit it to the thread pool.

        Returns the task's result directly when ``config.compile_threads`` is
        <= 1, otherwise the value returned by the pool's ``submit``.
        """
        if config.compile_threads <= 1:
            return task()
        return cls.pool().submit(task)

    @classmethod
    def map(cls, fn, seq):
        """Apply ``fn`` to each item of ``seq`` and return the results as a list.

        Runs serially when threading is disabled or ``seq`` has at most one
        item; otherwise all items are submitted to the thread pool first and
        the results are collected in input order.
        """
        if config.compile_threads <= 1 or len(seq) <= 1:
            return list(map(fn, seq))
        return [t.result() for t in [cls.pool().submit(fn, x) for x in seq]]

    def triton(self, source_code):
        """Compile a Triton kernel, in a worker process when parallelism is enabled.

        Returns a ``TritonFuture`` when ``config.compile_threads`` is > 1,
        otherwise the loaded kernel itself. Also starts the compile timer.
        """
        _compile_start()
        if self._context_keepalive is None:
            # Workaround `CUDA: Error- context is destroyed`
            self._context_keepalive = torch.tensor([1], device="cuda")

        if config.compile_threads > 1:
            major, minor = torch.cuda.get_device_capability()
            device = torch.cuda.current_device()
            # Pack the capability into one integer: major * 10 + minor.
            cc = major * 10 + minor
            future = self.process_pool().submit(
                _worker_compile, source_code, cc, device
            )
            return TritonFuture(source_code, future)
        else:
            return _load_kernel(source_code)

    def cpp(self, source_code):
        """Compile C++ ``source_code`` via ``submit`` and yield its ``kernel`` symbol."""

        def task():
            return CppCodeCache.load(source_code).kernel

        return self.submit(task)

    def wait(self, scope: Dict[str, Any]):
        """Replace pending futures in ``scope`` with their results, in place.

        Always stops the compile timer, whether or not threading is enabled.
        """
        if config.compile_threads > 1:
            # Iterate over a snapshot because the dict is updated in the loop.
            for key, result in list(scope.items()):
                if isinstance(result, (Future, TritonFuture)):
                    scope[key] = result.result()

        _compile_end()
|
|
|
|
|
|
# Import-time side effect: start the compile workers as soon as this module is
# loaded (a no-op when config.compile_threads <= 1).
AsyncCompile.warm_pool()
|