Explicit vectorization support for TorchInductor (#87068)
In this PR, we replace OMP SIMD with `aten::vec` to improve TorchInductor's vectorization performance. Take `res = torch.exp(torch.add(x, y))` as an example: when `config.cpp.simdlen` is 8, the generated code is as follows.
```C++
extern "C" void kernel(const float* __restrict__ in_ptr0,
                       const float* __restrict__ in_ptr1,
                       float* __restrict__ out_ptr0,
                       const long ks0,
                       const long ks1)
{
    #pragma omp parallel num_threads(48)
    {
        #pragma omp for
        for(long i0=0; i0<((ks0*ks1) / 8); ++i0)
        {
            auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + 8*i0);
            auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr1 + 8*i0);
            auto tmp2 = tmp0 + tmp1;
            auto tmp3 = tmp2.exp();
            tmp3.store(out_ptr0 + 8*i0);
        }
        #pragma omp for simd simdlen(4)
        for(long i0=8*(((ks0*ks1) / 8)); i0<ks0*ks1; ++i0)
        {
            auto tmp0 = in_ptr0[i0];
            auto tmp1 = in_ptr1[i0];
            auto tmp2 = tmp0 + tmp1;
            auto tmp3 = std::exp(tmp2);
            out_ptr0[i0] = tmp3;
        }
    }
}
```
The overall pipeline is as follows.
- Check whether the loop body can be vectorized with `aten::vec`. The checker consists of two parts: [one](bf66991fc4/torch/_inductor/codegen/cpp.py (L702)) checks whether all the `ops` are supported, and the [other](355326faa3/torch/_inductor/codegen/cpp.py (L672)) checks whether the data accesses can be vectorized (see the sketch after this list).
  - [`CppSimdVecKernelChecker`](355326faa3/torch/_inductor/codegen/cpp.py (L655))
- Create the `aten::vec` kernel and the original OMP SIMD kernel. The OMP SIMD kernel serves as the tail loop when the main loop is vectorized.
  - [`CppSimdVecKernel`](355326faa3/torch/_inductor/codegen/cpp.py (L601))
  - [`CppSimdVecOverrides`](355326faa3/torch/_inductor/codegen/cpp.py (L159)): the ops supported on top of `aten::vec`
  - Create kernel
    - [`aten::vec` kernel](355326faa3/torch/_inductor/codegen/cpp.py (L924))
    - [`Original CPP kernel - OMP SIMD`](355326faa3/torch/_inductor/codegen/cpp.py (L929))
- Generate code
  - [`CppKernelProxy`](355326faa3/torch/_inductor/codegen/cpp.py (L753)) combines the `aten::vec` kernel and the original C++ kernel
    - [Vectorize the innermost loop](355326faa3/torch/_inductor/codegen/cpp.py (L753))
    - [Generate code](355326faa3/torch/_inductor/codegen/cpp.py (L821))
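The data-access part of the check boils down to a simple sympy test on each index expression: an access is vectorizable along the innermost loop variable when the index either does not depend on that variable or advances by exactly one per iteration. A minimal standalone sketch of the idea (mirroring the `is_single_step_var`/`is_var_irrevelant` helpers added in this patch, not the exact implementation):

```python
import sympy

def is_single_step_var(var: sympy.Symbol, index: sympy.Expr) -> bool:
    # Contiguous access: substituting var -> var + 1 changes the index by exactly 1.
    return sympy.simplify(index.subs(var, var + 1) - index) == 1

def is_var_irrelevant(var: sympy.Symbol, index: sympy.Expr) -> bool:
    # Broadcast access: the index does not depend on var at all.
    return not sympy.expand(index).has(var)

def is_legal_data_access(var: sympy.Symbol, index: sympy.Expr) -> bool:
    return is_var_irrelevant(var, index) or is_single_step_var(var, index)

i, n = sympy.symbols("i n", integer=True)
assert is_legal_data_access(i, i + 8 * n)   # stride-1 in i: vectorizable
assert is_legal_data_access(i, 3 * n)       # independent of i: vectorizable
assert not is_legal_data_access(i, 2 * i)   # stride-2 in i: rejected
```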
Next steps:
- [x] Support reduction
- [x] Vectorize the tail loop with `aten::vec`
- [ ] Support BF16
- [ ] Optimize the loop condition and loop index calculation by replacing `div` with `add`
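The last item refers to how the main/tail split is currently expressed: both the main-loop trip count and the tail-loop start come from an integer division by `config.cpp.simdlen`, visible in the generated code above as `(ks0*ks1) / 8` and `8*(((ks0*ks1) / 8))`. A minimal sketch of that arithmetic, using hypothetical helper names and made-up sizes:

```python
def split_innermost_loop(numel: int, simdlen: int):
    # Vectorized main loop: numel // simdlen iterations, simdlen elements each.
    main_iters = numel // simdlen
    # Scalar (OMP SIMD) tail loop: elements [tail_begin, numel).
    tail_begin = main_iters * simdlen
    return main_iters, tail_begin

# e.g. a 150-element buffer with config.cpp.simdlen == 8:
assert split_innermost_loop(150, 8) == (18, 144)  # tail handles elements 144..149
```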
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87068
Approved by: https://github.com/jgong5, https://github.com/jansel
Committed by: PyTorch MergeBot
Parent: a95419b47e
Commit: 6541e51ffd

```
@@ -1036,6 +1036,7 @@ def main():
                'lib/*.pdb',
                'lib/torch_shm_manager',
                'lib/*.h',
                'include/*.h',
                'include/ATen/*.h',
                'include/ATen/cpu/*.h',
                'include/ATen/cpu/vec/vec256/*.h',
```

```
@@ -37,7 +37,7 @@ try:
    import torch._inductor.config
    from functorch.compile import config as functorch_config
    from torch._decomp import get_decompositions
    from torch._inductor import config
    from torch._inductor import codecache, config, metrics
    from torch._inductor.compile_fx import compile_fx, complex_memory_overlap
    from torch._inductor.ir import IndexingDiv, ModularIndexing
    from torch._inductor.sizevars import SizeVarAllocator
@@ -53,7 +53,6 @@ except (ImportError, AssertionError) as e:
        sys.exit(0)
    raise unittest.SkipTest("requires sympy/functorch/filelock")


HAS_CPU = False
try:
    from subprocess import CalledProcessError
@@ -4416,6 +4415,75 @@ if HAS_CPU:
            self.assertFalse(complex_memory_overlap(gathered))
            self.assertFalse(complex_memory_overlap(gathered.t()))

        # Currently, we enabled AVX2 and AVX512 for vectorization. If the platform is not
        # supported, the vectorization will not work and skip this test case. For ARM or
        # other platforms support, we just need to add the ISA info to the supported_vector_isa
        # and include proper aten vectorization head file.
        @unittest.skipIf(
            not codecache.get_cpu_proc_info(), "Does not support vectorization"
        )
        @patch("torch.cuda.is_available", lambda: False)
        def test_vec_kernel_cpu_only(self):
            def fn(x1, x2):
                # Current, there are some limitations as follows.
                # rsqrt:
                #  assert [both a fallback and a decomp for same kernel: aten.rsqrt.default]
                # round:
                #  couldn't find symbolic meta function/decomposition
                # fmod/logical_and/logic_or:
                #  vec kernel has not support to_type
                x = torch.abs(x1)
                x = torch.sin(x)
                x = torch.neg(x)
                x = torch.square(x)
                x = torch.sigmoid(x)
                x = torch.relu(x)
                x = torch.cos(x)
                x = torch.exp(x)
                x = torch.sqrt(x)
                x = torch.add(x, x1)
                x = torch.sub(x, x2)
                x = torch.mul(x, x1)
                x = torch.div(x, x1)
                x = torch.pow(x, 10)
                x = torch.log(x)
                x = torch.floor(x)
                x = torch.ceil(x)
                x = torch.trunc(x)
                x = torch.lgamma(x)
                x = torch.fmod(x, x2)
                res = x + x2
                return (res,)

            x1 = torch.randn((10, 20))
            x2 = torch.randn((10, 20))

            with patch.object(config.cpp, "simdlen", 8):
                torch._dynamo.reset()
                metrics.reset()
                traced = make_fx(fn)(x1, x2)
                compiled = compile_fx_inner(traced, [x1, x2])
                assert same(fn(x1, x2)[0], compiled([x1, x2])[0], equal_nan=True)
                assert metrics.generated_cpp_vec_kernel_count == 1

                torch._dynamo.reset()
                metrics.reset()
                x1 = x1.permute(1, 0)
                x2 = torch.randn((20, 10))
                traced = make_fx(fn)(x1, x2)
                compiled = compile_fx_inner(traced, [x1, x2])
                assert same(fn(x1, x2)[0], compiled([x1, x2])[0], equal_nan=True)
                assert metrics.generated_cpp_vec_kernel_count == 1

                torch._dynamo.reset()
                metrics.reset()
                x1 = torch.randn((10, 7))
                x2 = torch.randn((10, 7))
                traced = make_fx(fn)(x1, x2)
                compiled = compile_fx_inner(traced, ([x1, x2]))
                assert same(fn(x1, x2)[0], compiled([x1, x2])[0], equal_nan=True)
                assert metrics.generated_cpp_vec_kernel_count == 1


if HAS_CUDA:
    import triton
```

```
@@ -1,4 +1,5 @@
import base64
import enum
import functools
import getpass
import hashlib
@@ -146,11 +147,81 @@ def is_gcc():
    return re.search(r"(gcc|g\+\+)", cpp_compiler())


class _SupportedVecIsa(enum.Enum):
    AVX512 = 1
    AVX2 = 2
    INVALID = -1

    def __bool__(self):
        return self != _SupportedVecIsa.INVALID

    @staticmethod
    def isa_str(supported_isa: enum.Enum):
        if supported_isa == _SupportedVecIsa.AVX512:
            return "avx512"
        elif supported_isa == _SupportedVecIsa.AVX2:
            return "avx2"
        else:
            return ""

    @staticmethod
    def vec_macro(supported_isa: enum.Enum):
        if supported_isa == _SupportedVecIsa.AVX512:
            return "CPU_CAPABILITY_AVX512"
        elif supported_isa == _SupportedVecIsa.AVX2:
            return "CPU_CAPABILITY_AVX2"
        else:
            return ""


# Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content
# might have too much redundant content that is useless for ISA check. Hence,
# we only cache some key isa information.
@functools.lru_cache(1)
def get_cpu_proc_info():
    if sys.platform != "linux":
        return []

    isa_info = []
    with open("/proc/cpuinfo") as _cpu_info:
        _cpu_info_content = _cpu_info.read()
        if _SupportedVecIsa.isa_str(_SupportedVecIsa.AVX512) in _cpu_info_content:
            isa_info.append(_SupportedVecIsa.AVX512)

        if _SupportedVecIsa.isa_str(_SupportedVecIsa.AVX2) in _cpu_info_content:
            isa_info.append(_SupportedVecIsa.AVX2)

    return isa_info


def supported_vector_isa():
    # TODO: Add ARM Vec here.
    # Dict(k: isa, v: number of float element)
    vec_isa_info = {
        _SupportedVecIsa.AVX512: 16,
        _SupportedVecIsa.AVX2: 8,
    }

    if config.cpp.simdlen is None or config.cpp.simdlen <= 1:
        return _SupportedVecIsa.INVALID

    cpu_info_content = get_cpu_proc_info()
    for isa in vec_isa_info.keys():
        if isa in cpu_info_content and config.cpp.simdlen == vec_isa_info[isa]:
            return isa

    return _SupportedVecIsa.INVALID


def cpp_compile_command(input, output, include_pytorch=False):
    if include_pytorch:
    valid_isa = supported_vector_isa()
    if include_pytorch or valid_isa:
        ipaths = cpp_extension.include_paths() + [sysconfig.get_path("include")]
        lpaths = cpp_extension.library_paths() + [sysconfig.get_config_var("LIBDIR")]
        libs = ["c10", "torch", "torch_cpu", "torch_python", "gomp"]
        macros = _SupportedVecIsa.vec_macro(valid_isa)
        if macros:
            macros = f"-D{macros}"
    else:
        # Note - this is effectively a header only inclusion. Usage of some header files may result in
        # symbol not found, if those header files require a library.
@@ -159,17 +230,19 @@ def cpp_compile_command(input, output, include_pytorch=False):
        ipaths = cpp_extension.include_paths() + [sysconfig.get_path("include")]
        lpaths = []
        libs = ["gomp"]
        macros = ""
    ipaths = " ".join(["-I" + p for p in ipaths])
    lpaths = " ".join(["-L" + p for p in lpaths])
    libs = " ".join(["-l" + p for p in libs])

    return re.sub(
        r"[ \n]+",
        " ",
        f"""
            {cpp_compiler()} -shared -fPIC -Wall -std=c++14 -Wno-unused-variable
            {ipaths} {lpaths} {libs}
            {cpp_compiler()} {input} -shared -fPIC -Wall -std=c++14 -Wno-unused-variable
            {ipaths} {lpaths} {libs} {macros}
            -march=native -O3 -ffast-math -fno-finite-math-only -fopenmp
            -o{output} {input}
            -o{output}
        """,
    ).strip()
```

```
@@ -1,6 +1,7 @@
import contextlib
import dataclasses
import functools
from copy import deepcopy
from pathlib import Path
from typing import Dict, List

@@ -9,8 +10,9 @@ import sympy
import torch
from torch._prims_common import is_float_dtype

from .. import codecache, config
from ..utils import sympy_product, sympy_symbol
from .. import codecache, config, ir, metrics
from ..codegen.wrapper import WrapperCodeGen
from ..utils import sympy_product, sympy_subs, sympy_symbol
from ..virtualized import ops, V
from .common import (
    BracesBuffer,
@@ -120,6 +122,13 @@ def float16_reduction_prefix(rtype):
    return prefix


def parallel_num_threads():
    threads = config.cpp.threads
    if threads < 1:
        threads = torch.get_num_threads()
    return threads


@functools.lru_cache()
def cpp_prefix():
    path = Path(__file__).parent / "cpp_prefix.h"
@@ -151,6 +160,135 @@ class CppPrinter(ExprPrinter):
cexpr = CppPrinter().doprint


class CppVecOverrides(OpOverrides):
    """Map element-wise ops to aten vectorization C++"""

    @staticmethod
    def add(a, b):
        return f"{a} + {b}"

    @staticmethod
    def sub(a, b):
        return f"{a} - {b}"

    @staticmethod
    def mul(a, b):
        return f"{a} * {b}"

    @staticmethod
    def div(a, b):
        return f"{a} / {b}"

    @staticmethod
    def abs(x):
        return f"{x}.abs()"

    @staticmethod
    def sin(x):
        return f"{x}.sin()"

    @staticmethod
    def cos(x):
        return f"{x}.cos()"

    @staticmethod
    def exp(x):
        return f"{x}.exp()"

    @staticmethod
    def sqrt(x):
        return f"{x}.sqrt()"

    @staticmethod
    def rsqrt(x):
        return f"{x}.rsqrt()"

    @staticmethod
    def pow(a, b):
        return f"{a}.pow({b})"

    @staticmethod
    def log(x):
        return f"{x}.log()"

    @staticmethod
    def round(x):
        return f"{x}.round()"

    @staticmethod
    def floor(x):
        return f"{x}.floor()"

    @staticmethod
    def ceil(x):
        return f"{x}.ceil()"

    @staticmethod
    def trunc(x):
        return f"{x}.trunc()"

    @staticmethod
    def fmod(a, b):
        return f"{a}.fmod({b})"

    @staticmethod
    def lgamma(x):
        return f"{x}.lgamma()"

    @staticmethod
    def logical_and(a, b):
        return f"{a} && {b}"

    @staticmethod
    def logical_or(a, b):
        return f"{a} || {b}"

    @staticmethod
    def tanh(a):
        return f"{a}.tanh()"

    @staticmethod
    def reciprocal(a):
        return f"{a}.reciprocal()"

    @staticmethod
    def constant(val, dtype):
        if val == float("inf"):
            quote = f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
        elif val == float("-inf"):
            quote = f"-std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
        elif val is True or val is False:
            quote = f"static_cast<{DTYPE_TO_CPP[dtype]}>({str(val).lower()})"
        else:
            quote = f"static_cast<{DTYPE_TO_CPP[dtype]}>({repr(val)})"
        return f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>({quote})"

    @staticmethod
    def relu(x):
        return f"at::vec::clamp_min({x}, decltype({x})(0))"

    @staticmethod
    def sigmoid(x):
        return f"decltype({x})(1)/(decltype({x})(1) + {x}.neg().exp())"

    @staticmethod
    def neg(x):
        return f"{x}.neg()"

    @staticmethod
    def floordiv(a, b):
        # a and b are integer type
        _t = f"decltype({a})"
        quot = f"{a} / {b}"
        rem = f"{a} % {b}"
        return f"(({a} < {_t}(0)) != ({b} < {_t}(0)) ? ({rem} != {_t}(0) ? {quot} - {_t}(1) : {quot}) : {quot})"

    @staticmethod
    def truncdiv(a, b):
        # a and b are integer type
        return f"{a} / {b}"


class CppOverrides(OpOverrides):
    """Map element-wise ops to C++"""

@@ -413,9 +551,7 @@ class CppKernel(Kernel):
        return V.graph.sizevars.size_hint(sympy_product(self.call_ranges))

    def codegen_loops(self, code, worksharing):
        threads = config.cpp.threads
        if threads < 1:
            threads = torch.get_num_threads()
        threads = parallel_num_threads()

        loops = [LoopLevel(var, size) for var, size in zip(self.itervars, self.ranges)]
        loops, reductions = LoopNest(loops[: self.reduction_depth]), LoopNest(
@@ -427,7 +563,7 @@ class CppKernel(Kernel):
            # TODO(jansel): detect stride-1 dimension and vectorize that
            if reductions:
                reductions.loops[-1].simd = True
            else:
            elif loops:
                loops.loops[-1].simd = True

        par_depth = 0
@@ -509,6 +645,265 @@ class CppKernel(Kernel):
        (self.loads, self.compute, self.stores, self.cse) = prior


class CppVecKernel(CppKernel):
    overrides = CppVecOverrides

    def __init__(self, args, num_threads):
        super(CppVecKernel, self).__init__(args, num_threads)
        self.simd_len = config.cpp.simdlen
        metrics.generated_cpp_vec_kernel_count += 1

    def is_single_step_var(self, var: sympy.Symbol, index: sympy.Expr):
        replacement = {var: var + 1}
        new_index = sympy_subs(index, replacement)
        delta = sympy.simplify(new_index - index)
        return delta == 1

    def is_var_irrevelant(self, var: sympy.Symbol, index: sympy.Expr):
        expanded_index = sympy.expand(index)
        return not expanded_index.has(var)

    def transform_index(self, index: sympy.Expr):
        expanded_index = sympy.expand(index)
        assert self.simd_len
        assert self.simd_len > 0
        most_inner_var = self.itervars[-1]
        replacement = {most_inner_var: most_inner_var * self.simd_len}
        new_index = sympy_subs(expanded_index, replacement)
        return new_index

    def load(self, name: str, index: sympy.Expr):
        var = self.args.input(name)
        index = self.rename_indexing(index)

        expanded_index = sympy.expand(index)
        new_index = self.transform_index(index)

        if expanded_index == new_index:
            line = f"at::vec::Vectorized<float>({var}[{cexpr(index)}])"
        else:
            line = f"at::vec::Vectorized<float>::loadu({var} + {cexpr(new_index)})"

        return self.cse.generate(self.loads, line)

    def store(self, name, index, value, mode=None):
        assert "buf" in name
        var = self.args.output(name)
        index = self.rename_indexing(index)
        assert mode is None

        expanded_index = sympy.expand(index)
        new_index = self.transform_index(index)
        assert new_index != expanded_index
        line = f"{value}.store({var} + {cexpr(new_index)});"
        self.stores.writeline(name, line)


class CppVecKernelChecker(CppVecKernel):
    def __init__(self, args, num_threads):
        super(CppVecKernelChecker, self).__init__(args, num_threads)

        # Since this kernel is only for checker but does not genreate any
        # code, so we need to decrease the kernel count.
        metrics.generated_kernel_count -= 1
        metrics.generated_cpp_vec_kernel_count -= 1

        # Used to recorde the graph wrapper code as the wrapper_code status could be
        # changed during graph run.
        self._orig_wrapper_code = None

        self.simd_vec = True
        self.fast_vec_list = []
        for k, v in CppVecOverrides.__dict__.items():
            if isinstance(v, staticmethod):
                self.fast_vec_list.append(k)
        self.exit_stack = contextlib.ExitStack()

    def is_legal_data_access(self, var: sympy.Symbol, index: sympy.Expr):
        return self.is_var_irrevelant(var, index) or self.is_single_step_var(var, index)

    def could_vec(self, name: str, index: sympy.Expr):
        if V.graph.get_dtype(name) is not torch.float:
            return False

        assert self.itervars is not None
        # Not a loop
        if len(self.itervars) == 0:
            return False

        most_inner_var = self.itervars[-1]
        return self.is_legal_data_access(most_inner_var, index)

    def load(self, name: str, index: sympy.Expr):
        index = self.rename_indexing(index)

        self.simd_vec = self.simd_vec and self.could_vec(name, index)
        return self.simd_vec

    def store(self, name, index, value, mode=None):
        assert "buf" in name
        index = self.rename_indexing(index)

        if mode:
            self.simd_vec = False
            return False

        self.simd_vec = self.simd_vec and self.could_vec(name, index)
        return self.simd_vec

    def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
        self.simd_vec = False
        return self.simd_vec

    def __exit__(self, exc_type, exc_val, exc_tb):
        assert self._orig_wrapper_code is not None
        # Restore the wrapper_code
        V.graph.wrapper_code = self._orig_wrapper_code
        self.exit_stack.__exit__(exc_type, exc_val, exc_tb)

    def __enter__(self):
        # Recorde the graph wrapper code. The wrapper_code status could be
        # changed during graph run. Regarding this checker, we also need to
        # run the graph but we don't expect to change any status that would
        # impact the code generation. Hence, we record the graph wapper code
        # and replace it with a dummy warpper_code and then restore to the
        # original one as long as the checker is finished.
        self._orig_wrapper_code = V.graph.wrapper_code
        V.graph.wrapper_code = WrapperCodeGen()

        class VecCheckerProxy:
            @staticmethod
            def __getattr__(name):
                def inner(*args, **kwargs):
                    if not (name in self.fast_vec_list):
                        self.simd_vec = False
                    return self.simd_vec

                return inner

            @staticmethod
            def load(name: str, index: sympy.Expr):
                return self.load(name, index)

            @staticmethod
            def store(name, index, value, mode=None):
                return self.store(name, index, value, mode=mode)

            @staticmethod
            def reduction(name, dtype, src_dtype, reduction_type, index, value):
                return self.reduction(
                    name, dtype, src_dtype, reduction_type, index, value
                )

            @staticmethod
            def constant(val, dtype):
                supported_dtype = (torch.float32, torch.int32)
                is_supported_dtype = dtype in (supported_dtype)
                if not is_supported_dtype:
                    self.simd_vec = False
                return is_supported_dtype

            @staticmethod
            def index_expr(expr, dtype):
                self.simd_vec = False
                return self.cse.newvar()

            @staticmethod
            def indirect_indexing(index_var):
                return sympy.Symbol(str(index_var))

            @staticmethod
            def masked(mask, body, other):
                return V.kernel.cse.newvar()

        self.exit_stack.enter_context(V.set_ops_handler(VecCheckerProxy()))
        self.exit_stack.enter_context(V.set_kernel_handler(self))
        return self


class CppKernelProxy(CppKernel):
    def __init__(self, args=None, num_threads=None):
        super(CppKernelProxy, self).__init__(args, num_threads)
        self.simd_vec_kernel = None
        self.simd_omp_kernel = None

    def vectorize_most_inner_loop(self, loop_nest):
        loop_nest.split_most_inner_loop(config.cpp.simdlen)
        loop_with_tail = loop_nest.loops[-1]
        assert isinstance(loop_with_tail, LoopLevelWithTail)

        self.simd_vec_kernel.simd = False
        self.simd_vec_kernel.fast_vec = True

        loop_with_tail.tail_loop.simd_omp = True
        # We chope the loop into two cubes by the config.cpp.simdlen - main loop and tail loop.
        # Regarding the main loop, it is straightforward that it could be vectorized with
        # config.cpp.simdlen. But for the tail loop, it still could be vectorized. For example,
        # if the config.cpp.simdlen is 8(256bits), then the tail loop still could be vectorized
        # as 4(128bits).
        loop_with_tail.tail_loop.simd_len = int(config.cpp.simdlen / 2)
        loop_with_tail.tail_loop.simd_vec = False

        loop_with_tail.main_loop_body = self.simd_vec_kernel
        loop_with_tail.tail_loop_body = self.simd_omp_kernel
        return loop_nest

    def codegen_loops(self, code, worksharing):
        threads = parallel_num_threads()

        if self.simd_vec_kernel is None:
            assert self.simd_omp_kernel
            return self.simd_omp_kernel.codegen_loops(code, worksharing)

        assert self.simd_vec_kernel.itervars == self.simd_omp_kernel.itervars
        assert self.simd_vec_kernel.ranges == self.simd_omp_kernel.ranges

        itervars = self.simd_vec_kernel.itervars
        rangs = self.simd_vec_kernel.ranges
        loops = [LoopLevel(var, size) for var, size in zip(itervars, rangs)]

        # TODO: Support reductions
        loops_nest_non_reduc, _ = LoopNest(loops[: self.reduction_depth]), LoopNest(
            loops[self.reduction_depth :]
        )

        assert config.cpp.simdlen
        loops_nest_non_reduc.loops[-1].simd_omp = True

        par_depth = 0
        if loops_nest_non_reduc:
            par_depth = self.simd_vec_kernel.decide_parallel_depth(
                self.simd_vec_kernel.call_ranges[: self.reduction_depth], threads
            )

        with contextlib.ExitStack() as stack:
            if par_depth:
                worksharing.parallel(threads)
                loops_nest_non_reduc.mark_parallel(par_depth)
            elif threads > 1:
                if worksharing.single():
                    stack.enter_context(code.indent())

            self.vectorize_most_inner_loop(loops_nest_non_reduc)

            for loop in loops_nest_non_reduc.loops[0:-1]:
                code.writelines(loop.lines())
                stack.enter_context(code.indent())

            loop_with_tail: LoopLevelWithTail = loops_nest_non_reduc.loops[-1]
            for loop, kernel in (
                (loop_with_tail.main_loop, loop_with_tail.main_loop_body),
                (loop_with_tail.tail_loop, loop_with_tail.tail_loop_body),
            ):

                code.writelines(loop.lines())
                with contextlib.ExitStack() as stack:
                    stack.enter_context(code.indent())
                    code.splice(kernel.loads)
                    code.splice(kernel.compute)
                    code.splice(kernel.stores)


class CppScheduling:
    def __init__(self, scheduler):
        self.scheduler = scheduler
@@ -532,38 +927,113 @@ class CppScheduling:
    def can_fuse_vertical(cls, node1, node2):
        return cls.can_fuse_horizontal(node1, node2) and not node1.is_reduction()

    def codegen_nodes(self, nodes):
        """
        Turn an set of pre-fused nodes into a C++ kernel.
        """
        kernel_group = self.kernel_group
        scheduler = self.scheduler
    def can_vec(self, nodes):
        # TODO: Query cpu arch and vec length from aten
        if not codecache.supported_vector_isa():
            return False

        _, (group, reduction_group) = max(
            nodes, key=lambda x: int(x.is_reduction())
        ).group
        in_suffix = False

        with kernel_group.new_kernel() as kernel:
            vars, reduction_vars = kernel.set_ranges(group, reduction_group)

        with CppVecKernelChecker(
            deepcopy(self.kernel_group.args), parallel_num_threads()
        ) as kernel_checker:
            vars, reduction_vars = kernel_checker.set_ranges(group, reduction_group)
            for node in nodes:
                if node.group[1] in [
                    (group, reduction_group),
                    (group + reduction_group, ()),
                ]:
                    assert not in_suffix
                    node.run(vars, reduction_vars)
                else:
                    in_suffix = True
                    assert node.group[1] == (
                        group,
                        (),
                    ), f"unexpected group: {node.group[1]} != {group}, {reduction_group}"
                    # we can fuse in some extra pointwise into the suffix
                    with kernel.write_to_suffix():
                        node.run(vars, ())
                    node.run(vars, ())

        kernel_group.finalize_kernel(kernel, scheduler)
        return kernel_checker.simd_vec

    def _codegen_nodes_impl(self, nodes, is_simd_vec=False):
        """
        Turn an set of pre-fused nodes into a C++ kernel.
        """
        kernel_group = self.kernel_group
        _, (group, reduction_group) = max(
            nodes, key=lambda x: int(x.is_reduction())
        ).group

        def create_kernel(_is_simd_vec):
            in_suffix = False

            with kernel_group.new_kernel(_is_simd_vec) as kernel:
                vars, reduction_vars = kernel.set_ranges(group, reduction_group)

                for node in nodes:
                    if node.group[1] in [
                        (group, reduction_group),
                        (group + reduction_group, ()),
                    ]:
                        assert not in_suffix
                        node.run(vars, reduction_vars)
                    else:
                        in_suffix = True
                        assert node.group[1] == (
                            group,
                            (),
                        ), f"unexpected group: {node.group[1]} != {group}, {reduction_group}"
                        # we can fuse in some extra pointwise into the suffix
                        with kernel.write_to_suffix():
                            node.run(vars, ())
            return kernel

        org_inplace_buffers_flag = config.inplace_buffers
        if is_simd_vec:
            # Create vectorization kernel
            cpp_vec_kernel = create_kernel(True)

            # Since a kernel is divided into two parts - vectorization and non-vectorization.
            # And the two parts share the same global contexts like V.graph.wrapper_code,
            # V.kernel.args. But the vectorization kernel generation has updated these global
            # contexts. Hence, the non-vectorization kernel should not do this again to avoid
            # conext conflict. By now, we only control the config.inplace_buffers. In the future,
            # we could maintain more contexts.
            config.inplace_buffers = False

            # Create non-vectorization kernel
            cpp_kernel = create_kernel(False)

            # Restore the inplace_buffers flag
            config.inplace_buffers = org_inplace_buffers_flag
            return (cpp_vec_kernel, cpp_kernel)
        else:
            return (None, create_kernel(False))

    def codegen_nodes(self, nodes):
        """
        Turn an set of pre-fused nodes into a C++ kernel.
        """
        kernel_group = self.kernel_group

        can_be_simd_vec = self.can_vec(nodes)
        simd_vec_kernel, simd_omp_kernel = self._codegen_nodes_impl(
            nodes, can_be_simd_vec
        )

        assert simd_omp_kernel
        metrics.generated_kernel_count -= 1
        # Maitain the metrics kernel count
        if simd_vec_kernel:
            metrics.generated_kernel_count -= 1

        cpp_kernel_proxy = CppKernelProxy(
            kernel_group.args, kernel_group.ws.num_threads
        )
        cpp_kernel_proxy.simd_vec_kernel = simd_vec_kernel
        cpp_kernel_proxy.simd_omp_kernel = simd_omp_kernel

        kernel_group.finalize_kernel(cpp_kernel_proxy, None)

    def flush(self):
        self.kernel_group.codegen_define_and_call(V.graph.wrapper_code)
@@ -580,8 +1050,11 @@ class KernelGroup:
        self.stack.enter_context(self.ws)
        self.count = 0

    def new_kernel(self):
        return CppKernel(self.args, self.ws.num_threads)
    def new_kernel(self, simd_vec=False):
        if simd_vec:
            return CppVecKernel(self.args, parallel_num_threads())
        else:
            return CppKernel(self.args, parallel_num_threads())

    def finalize_kernel(self, new_kernel, scheduler):
        self.count += 1
@@ -660,10 +1133,14 @@ class WorkSharing:

@dataclasses.dataclass
class LoopLevel:
    var: sympy.Expr
    size: sympy.Expr
    var: sympy.Expr = None
    size: sympy.Expr = None
    offset: sympy.Expr = sympy.Integer(0)
    steps: sympy.Expr = sympy.Integer(1)
    parallel: int = 0
    simd: bool = False
    simd_omp: bool = False
    simd_len: int = config.cpp.simdlen
    simd_vec: bool = False
    collapsed: bool = False
    reduction_vars: Dict[str, str] = None

@@ -675,26 +1152,40 @@ class LoopLevel:
            )
        else:
            reduction = ""
        simd = f"simd simdlen({config.cpp.simdlen})"
        simd = f"simd simdlen({self.simd_len})" if self.simd_omp else ""
        if self.parallel:
            # TODO(jansel): look into chunk size and other schedules
            line1 = f"#pragma omp for{reduction} "
            if self.parallel > 1:
                line1 += f" collapse({self.parallel})"
            if self.simd:
            if self.simd_omp:
                line1 = line1.replace(" for ", f" for {simd}")
        elif self.simd:
        elif self.simd_vec:
            line1 = ""
        elif self.simd_omp:
            line1 = f"#pragma omp {simd}{reduction}"
        elif not self.reduction_vars and codecache.is_gcc():
            line1 = "#pragma GCC ivdep"
        else:
            line1 = ""
        line2 = f"for({INDEX_TYPE} {self.var}=0; {self.var}<{cexpr(self.size)}; ++{self.var})"
        line2 = f"for({INDEX_TYPE} {self.var}={cexpr(self.offset)}; {self.var}<{cexpr(self.size)}; {self.var}+={cexpr(self.steps)})"
        if self.collapsed or not line1:
            return [line2]
        return [line1, line2]


class LoopLevelWithTail(LoopLevel):
    def __init__(self, main_loop: LoopLevel, tail_loop: LoopLevel):
        super().__init__()
        self.main_loop = main_loop
        self.tail_loop = tail_loop
        self.main_loop_body = None
        self.tail_loop_body = None

    def lines(self):
        raise AssertionError("Not Implemented")


@dataclasses.dataclass
class LoopNest:
    loops: List[LoopLevel]
@@ -711,7 +1202,35 @@ class LoopNest:
        loops[0].parallel = par_depth
        for i in range(1, par_depth):
            loops[i].collapsed = True
        loops[0].simd = loops[par_depth - 1].simd

    def split_most_inner_loop(self, factor):
        sympy_factor = sympy.Integer(factor)

        most_inner_loop = self.loops[-1]

        # If the most inner loop needs to be collapsed, we need to
        # exclude it since we need to split it into two loops. Meanwhile,
        # we still mark it as parallized.
        if most_inner_loop.collapsed:
            assert self.loops[0].parallel == len(self.loops)
            self.loops[0].parallel -= 1

        main_loop_range = ir.IndexingDiv(most_inner_loop.size, sympy_factor)

        main_loop = LoopLevel(most_inner_loop.var, main_loop_range)
        main_loop.parallel = 1 if most_inner_loop.parallel > 0 else 0
        main_loop.collapsed = False

        offset = main_loop_range * sympy_factor
        tail_loop = LoopLevel(most_inner_loop.var, most_inner_loop.size)
        tail_loop.offset = offset
        tail_loop.parallel = 1 if most_inner_loop.parallel > 0 else 0
        tail_loop.collapsed = False

        loop_with_tail = LoopLevelWithTail(main_loop, tail_loop)
        loop_with_tail.parallel = 0
        loop_with_tail.collapsed = False
        self.loops[-1] = loop_with_tail

    def codegen(self, code, stack):
        for loop in self.loops:
```

```
@@ -6,8 +6,11 @@
#include <omp.h>

#include <ATen/core/PhiloxRNGEngine.h>
#include <c10/util/Half.h>
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2)
#include <ATen/cpu/vec/vec.h>
#endif
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>

typedef at::Half half;
typedef at::BFloat16 bfloat16;
```

```
@@ -1,8 +1,12 @@
# counter for tracking how many kernels have been generated
generated_kernel_count = 0
generated_cpp_vec_kernel_count = 0


# reset all counters
def reset():
    global generated_kernel_count
    global generated_cpp_vec_kernel_count

    generated_kernel_count = 0
    generated_cpp_vec_kernel_count = 0
```