Explicit vectorization support for TorchInductor (#87068)

In this PR, we replace OMP SIMD with `aten::vec` to improve TorchInductor's vectorization performance. Take `res = torch.exp(torch.add(x, y))` as an example. With `config.cpp.simdlen` set to 8, the generated code is as follows.

```C++
extern "C" void kernel(const float* __restrict__ in_ptr0,
                       const float* __restrict__ in_ptr1,
                       float* __restrict__ out_ptr0,
                       const long ks0,
                       const long ks1)
{
    #pragma omp parallel num_threads(48)
    {
        #pragma omp for
        for(long i0=0; i0<((ks0*ks1) / 8); ++i0)
        {
            auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + 8*i0);
            auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr1 + 8*i0);
            auto tmp2 = tmp0 + tmp1;
            auto tmp3 = tmp2.exp();
            tmp3.store(out_ptr0 + 8*i0);
        }
        #pragma omp for simd simdlen(4)
        for(long i0=8*(((ks0*ks1) / 8)); i0<ks0*ks1; ++i0)
        {
            auto tmp0 = in_ptr0[i0];
            auto tmp1 = in_ptr1[i0];
            auto tmp2 = tmp0 + tmp1;
            auto tmp3 = std::exp(tmp2);
            out_ptr0[i0] = tmp3;
        }
    }
}

```
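
For reference, the new unit test in this PR drives the same path roughly as below. This is a condensed sketch: `make_fx`, `compile_fx_inner`, and `config.cpp.simdlen` are internal Inductor APIs at the time of this change, and the import locations are assumptions that may have moved since.

```python
import torch
from unittest.mock import patch

# Internal entry points (import paths assumed; see the test diff below).
from torch.fx.experimental.proxy_tensor import make_fx
from torch._inductor import config, metrics
from torch._inductor.compile_fx import compile_fx_inner


def fn(x, y):
    return (torch.exp(torch.add(x, y)),)


x = torch.randn(10, 20)
y = torch.randn(10, 20)

# Force an 8-float SIMD width (256-bit, i.e. AVX2) so the aten::vec path is taken.
with patch.object(config.cpp, "simdlen", 8):
    metrics.reset()
    compiled = compile_fx_inner(make_fx(fn)(x, y), [x, y])
    assert torch.allclose(fn(x, y)[0], compiled([x, y])[0])
    # Exactly one aten::vec C++ kernel should have been generated.
    assert metrics.generated_cpp_vec_kernel_count == 1
```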

The overall pipeline is as follows.
- Check whether the loop body can be vectorized with `aten::vec`. The checker consists of two parts: [one](bf66991fc4/torch/_inductor/codegen/cpp.py (L702)) checks whether all the `ops` are supported, and the [other](355326faa3/torch/_inductor/codegen/cpp.py (L672)) checks whether the data access can be vectorized (see the sketch after this list).
  - [`CppSimdVecKernelChecker`](355326faa3/torch/_inductor/codegen/cpp.py (L655))
- Create the `aten::vec` kernel and the original OMP SIMD kernel. The original OMP SIMD kernel serves as the tail loop when the main loop is vectorized.
  - [`CppSimdVecKernel`](355326faa3/torch/_inductor/codegen/cpp.py (L601))
  - [`CppSimdVecOverrides`](355326faa3/torch/_inductor/codegen/cpp.py (L159)): the ops we support on top of `aten::vec`
  - Create the kernels
    - [`aten::vec` kernel](355326faa3/torch/_inductor/codegen/cpp.py (L924))
    - [`Original CPP kernel - OMP SIMD`](355326faa3/torch/_inductor/codegen/cpp.py (L929))
- Generate code
  - [`CppKernelProxy`](355326faa3/torch/_inductor/codegen/cpp.py (L753)) combines the `aten::vec` kernel and the original C++ kernel
    - [Vectorize the innermost loop](355326faa3/torch/_inductor/codegen/cpp.py (L753))
    - [Generate code](355326faa3/torch/_inductor/codegen/cpp.py (L821))
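
The data-access check and the index transform reduce to simple symbolic manipulation of the innermost loop variable. Below is a minimal standalone sketch that mirrors `is_single_step_var`, `is_var_irrevelant`, and `transform_index` from the `cpp.py` diff; the helper names and the example index are illustrative only.

```python
import sympy


def is_single_step_var(var, index):
    # Unit-stride access: incrementing `var` by 1 moves the index by exactly 1.
    return sympy.simplify(index.subs(var, var + 1) - index) == 1


def is_var_irrelevant(var, index):
    # The index does not depend on `var` at all (e.g. a broadcast along `var`).
    return not sympy.expand(index).has(var)


def transform_index(index, inner_var, simd_len):
    # Once the innermost loop is vectorized, each iteration covers `simd_len`
    # elements, so the innermost loop variable is scaled accordingly.
    return sympy.expand(index).subs(inner_var, inner_var * simd_len)


i0, i1 = sympy.symbols("i0 i1")
index = i1 * 20 + i0  # contiguous (10, 20) tensor indexed as [i1, i0]
assert is_single_step_var(i0, index)      # innermost access is unit-stride -> vectorizable
assert not is_single_step_var(i1, index)  # the outer variable strides by 20
print(transform_index(index, i0, 8))      # 8*i0 + 20*i1, cf. the `8*i0` offsets in the generated kernel above
```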

Next steps:
- [x] Support reduction
- [x] Vectorize the tail loop with `aten::vec`
- [ ] Support BF16
- [ ] Optimize the loop condition and loop index calculation by replacing `div` with `add`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87068
Approved by: https://github.com/jgong5, https://github.com/jansel
Author: Wang, Eikan
Date: 2022-11-04 05:28:15 +00:00
Committed by: PyTorch MergeBot
Parent: a95419b47e
Commit: 6541e51ffd
6 changed files with 707 additions and 39 deletions


@ -1036,6 +1036,7 @@ def main():
'lib/*.pdb',
'lib/torch_shm_manager',
'lib/*.h',
'include/*.h',
'include/ATen/*.h',
'include/ATen/cpu/*.h',
'include/ATen/cpu/vec/vec256/*.h',


@ -37,7 +37,7 @@ try:
import torch._inductor.config
from functorch.compile import config as functorch_config
from torch._decomp import get_decompositions
from torch._inductor import config
from torch._inductor import codecache, config, metrics
from torch._inductor.compile_fx import compile_fx, complex_memory_overlap
from torch._inductor.ir import IndexingDiv, ModularIndexing
from torch._inductor.sizevars import SizeVarAllocator
@ -53,7 +53,6 @@ except (ImportError, AssertionError) as e:
sys.exit(0)
raise unittest.SkipTest("requires sympy/functorch/filelock")
HAS_CPU = False
try:
from subprocess import CalledProcessError
@ -4416,6 +4415,75 @@ if HAS_CPU:
self.assertFalse(complex_memory_overlap(gathered))
self.assertFalse(complex_memory_overlap(gathered.t()))
# Currently, we enable AVX2 and AVX512 for vectorization. If the platform is not
# supported, vectorization will not work and this test case is skipped. To support ARM
# or other platforms, we just need to add the ISA info to supported_vector_isa
# and include the proper ATen vectorization header file.
@unittest.skipIf(
not codecache.get_cpu_proc_info(), "Does not support vectorization"
)
@patch("torch.cuda.is_available", lambda: False)
def test_vec_kernel_cpu_only(self):
def fn(x1, x2):
# Currently, there are some limitations as follows.
# rsqrt:
# assert [both a fallback and a decomp for same kernel: aten.rsqrt.default]
# round:
# couldn't find symbolic meta function/decomposition
# fmod/logical_and/logic_or:
# vec kernel does not support to_type
x = torch.abs(x1)
x = torch.sin(x)
x = torch.neg(x)
x = torch.square(x)
x = torch.sigmoid(x)
x = torch.relu(x)
x = torch.cos(x)
x = torch.exp(x)
x = torch.sqrt(x)
x = torch.add(x, x1)
x = torch.sub(x, x2)
x = torch.mul(x, x1)
x = torch.div(x, x1)
x = torch.pow(x, 10)
x = torch.log(x)
x = torch.floor(x)
x = torch.ceil(x)
x = torch.trunc(x)
x = torch.lgamma(x)
x = torch.fmod(x, x2)
res = x + x2
return (res,)
x1 = torch.randn((10, 20))
x2 = torch.randn((10, 20))
with patch.object(config.cpp, "simdlen", 8):
torch._dynamo.reset()
metrics.reset()
traced = make_fx(fn)(x1, x2)
compiled = compile_fx_inner(traced, [x1, x2])
assert same(fn(x1, x2)[0], compiled([x1, x2])[0], equal_nan=True)
assert metrics.generated_cpp_vec_kernel_count == 1
torch._dynamo.reset()
metrics.reset()
x1 = x1.permute(1, 0)
x2 = torch.randn((20, 10))
traced = make_fx(fn)(x1, x2)
compiled = compile_fx_inner(traced, [x1, x2])
assert same(fn(x1, x2)[0], compiled([x1, x2])[0], equal_nan=True)
assert metrics.generated_cpp_vec_kernel_count == 1
torch._dynamo.reset()
metrics.reset()
x1 = torch.randn((10, 7))
x2 = torch.randn((10, 7))
traced = make_fx(fn)(x1, x2)
compiled = compile_fx_inner(traced, ([x1, x2]))
assert same(fn(x1, x2)[0], compiled([x1, x2])[0], equal_nan=True)
assert metrics.generated_cpp_vec_kernel_count == 1
if HAS_CUDA:
import triton


@ -1,4 +1,5 @@
import base64
import enum
import functools
import getpass
import hashlib
@ -146,11 +147,81 @@ def is_gcc():
return re.search(r"(gcc|g\+\+)", cpp_compiler())
class _SupportedVecIsa(enum.Enum):
AVX512 = 1
AVX2 = 2
INVALID = -1
def __bool__(self):
return self != _SupportedVecIsa.INVALID
@staticmethod
def isa_str(supported_isa: enum.Enum):
if supported_isa == _SupportedVecIsa.AVX512:
return "avx512"
elif supported_isa == _SupportedVecIsa.AVX2:
return "avx2"
else:
return ""
@staticmethod
def vec_macro(supported_isa: enum.Enum):
if supported_isa == _SupportedVecIsa.AVX512:
return "CPU_CAPABILITY_AVX512"
elif supported_isa == _SupportedVecIsa.AVX2:
return "CPU_CAPABILITY_AVX2"
else:
return ""
# Cache the cpuinfo to avoid I/O overhead. The cpuinfo content may contain a lot of
# redundant information that is useless for the ISA check. Hence,
# we only cache the key ISA information.
@functools.lru_cache(1)
def get_cpu_proc_info():
if sys.platform != "linux":
return []
isa_info = []
with open("/proc/cpuinfo") as _cpu_info:
_cpu_info_content = _cpu_info.read()
if _SupportedVecIsa.isa_str(_SupportedVecIsa.AVX512) in _cpu_info_content:
isa_info.append(_SupportedVecIsa.AVX512)
if _SupportedVecIsa.isa_str(_SupportedVecIsa.AVX2) in _cpu_info_content:
isa_info.append(_SupportedVecIsa.AVX2)
return isa_info
def supported_vector_isa():
# TODO: Add ARM Vec here.
# Dict(k: isa, v: number of float elements)
vec_isa_info = {
_SupportedVecIsa.AVX512: 16,
_SupportedVecIsa.AVX2: 8,
}
if config.cpp.simdlen is None or config.cpp.simdlen <= 1:
return _SupportedVecIsa.INVALID
cpu_info_content = get_cpu_proc_info()
for isa in vec_isa_info.keys():
if isa in cpu_info_content and config.cpp.simdlen == vec_isa_info[isa]:
return isa
return _SupportedVecIsa.INVALID
def cpp_compile_command(input, output, include_pytorch=False):
if include_pytorch:
valid_isa = supported_vector_isa()
if include_pytorch or valid_isa:
ipaths = cpp_extension.include_paths() + [sysconfig.get_path("include")]
lpaths = cpp_extension.library_paths() + [sysconfig.get_config_var("LIBDIR")]
libs = ["c10", "torch", "torch_cpu", "torch_python", "gomp"]
macros = _SupportedVecIsa.vec_macro(valid_isa)
if macros:
macros = f"-D{macros}"
else:
# Note - this is effectively a header-only inclusion. Using some header files may result in
# "symbol not found" errors if those header files require a library.
@ -159,17 +230,19 @@ def cpp_compile_command(input, output, include_pytorch=False):
ipaths = cpp_extension.include_paths() + [sysconfig.get_path("include")]
lpaths = []
libs = ["gomp"]
macros = ""
ipaths = " ".join(["-I" + p for p in ipaths])
lpaths = " ".join(["-L" + p for p in lpaths])
libs = " ".join(["-l" + p for p in libs])
return re.sub(
r"[ \n]+",
" ",
f"""
{cpp_compiler()} -shared -fPIC -Wall -std=c++14 -Wno-unused-variable
{ipaths} {lpaths} {libs}
{cpp_compiler()} {input} -shared -fPIC -Wall -std=c++14 -Wno-unused-variable
{ipaths} {lpaths} {libs} {macros}
-march=native -O3 -ffast-math -fno-finite-math-only -fopenmp
-o{output} {input}
-o{output}
""",
).strip()


@ -1,6 +1,7 @@
import contextlib
import dataclasses
import functools
from copy import deepcopy
from pathlib import Path
from typing import Dict, List
@ -9,8 +10,9 @@ import sympy
import torch
from torch._prims_common import is_float_dtype
from .. import codecache, config
from ..utils import sympy_product, sympy_symbol
from .. import codecache, config, ir, metrics
from ..codegen.wrapper import WrapperCodeGen
from ..utils import sympy_product, sympy_subs, sympy_symbol
from ..virtualized import ops, V
from .common import (
BracesBuffer,
@ -120,6 +122,13 @@ def float16_reduction_prefix(rtype):
return prefix
def parallel_num_threads():
threads = config.cpp.threads
if threads < 1:
threads = torch.get_num_threads()
return threads
@functools.lru_cache()
def cpp_prefix():
path = Path(__file__).parent / "cpp_prefix.h"
@ -151,6 +160,135 @@ class CppPrinter(ExprPrinter):
cexpr = CppPrinter().doprint
class CppVecOverrides(OpOverrides):
"""Map element-wise ops to aten vectorization C++"""
@staticmethod
def add(a, b):
return f"{a} + {b}"
@staticmethod
def sub(a, b):
return f"{a} - {b}"
@staticmethod
def mul(a, b):
return f"{a} * {b}"
@staticmethod
def div(a, b):
return f"{a} / {b}"
@staticmethod
def abs(x):
return f"{x}.abs()"
@staticmethod
def sin(x):
return f"{x}.sin()"
@staticmethod
def cos(x):
return f"{x}.cos()"
@staticmethod
def exp(x):
return f"{x}.exp()"
@staticmethod
def sqrt(x):
return f"{x}.sqrt()"
@staticmethod
def rsqrt(x):
return f"{x}.rsqrt()"
@staticmethod
def pow(a, b):
return f"{a}.pow({b})"
@staticmethod
def log(x):
return f"{x}.log()"
@staticmethod
def round(x):
return f"{x}.round()"
@staticmethod
def floor(x):
return f"{x}.floor()"
@staticmethod
def ceil(x):
return f"{x}.ceil()"
@staticmethod
def trunc(x):
return f"{x}.trunc()"
@staticmethod
def fmod(a, b):
return f"{a}.fmod({b})"
@staticmethod
def lgamma(x):
return f"{x}.lgamma()"
@staticmethod
def logical_and(a, b):
return f"{a} && {b}"
@staticmethod
def logical_or(a, b):
return f"{a} || {b}"
@staticmethod
def tanh(a):
return f"{a}.tanh()"
@staticmethod
def reciprocal(a):
return f"{a}.reciprocal()"
@staticmethod
def constant(val, dtype):
if val == float("inf"):
quote = f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
elif val == float("-inf"):
quote = f"-std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
elif val is True or val is False:
quote = f"static_cast<{DTYPE_TO_CPP[dtype]}>({str(val).lower()})"
else:
quote = f"static_cast<{DTYPE_TO_CPP[dtype]}>({repr(val)})"
return f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>({quote})"
@staticmethod
def relu(x):
return f"at::vec::clamp_min({x}, decltype({x})(0))"
@staticmethod
def sigmoid(x):
return f"decltype({x})(1)/(decltype({x})(1) + {x}.neg().exp())"
@staticmethod
def neg(x):
return f"{x}.neg()"
@staticmethod
def floordiv(a, b):
# a and b are integer type
_t = f"decltype({a})"
quot = f"{a} / {b}"
rem = f"{a} % {b}"
return f"(({a} < {_t}(0)) != ({b} < {_t}(0)) ? ({rem} != {_t}(0) ? {quot} - {_t}(1) : {quot}) : {quot})"
@staticmethod
def truncdiv(a, b):
# a and b are integer type
return f"{a} / {b}"
class CppOverrides(OpOverrides):
"""Map element-wise ops to C++"""
@ -413,9 +551,7 @@ class CppKernel(Kernel):
return V.graph.sizevars.size_hint(sympy_product(self.call_ranges))
def codegen_loops(self, code, worksharing):
threads = config.cpp.threads
if threads < 1:
threads = torch.get_num_threads()
threads = parallel_num_threads()
loops = [LoopLevel(var, size) for var, size in zip(self.itervars, self.ranges)]
loops, reductions = LoopNest(loops[: self.reduction_depth]), LoopNest(
@ -427,7 +563,7 @@ class CppKernel(Kernel):
# TODO(jansel): detect stride-1 dimension and vectorize that
if reductions:
reductions.loops[-1].simd = True
else:
elif loops:
loops.loops[-1].simd = True
par_depth = 0
@ -509,6 +645,265 @@ class CppKernel(Kernel):
(self.loads, self.compute, self.stores, self.cse) = prior
class CppVecKernel(CppKernel):
overrides = CppVecOverrides
def __init__(self, args, num_threads):
super(CppVecKernel, self).__init__(args, num_threads)
self.simd_len = config.cpp.simdlen
metrics.generated_cpp_vec_kernel_count += 1
def is_single_step_var(self, var: sympy.Symbol, index: sympy.Expr):
replacement = {var: var + 1}
new_index = sympy_subs(index, replacement)
delta = sympy.simplify(new_index - index)
return delta == 1
def is_var_irrevelant(self, var: sympy.Symbol, index: sympy.Expr):
expanded_index = sympy.expand(index)
return not expanded_index.has(var)
def transform_index(self, index: sympy.Expr):
expanded_index = sympy.expand(index)
assert self.simd_len
assert self.simd_len > 0
most_inner_var = self.itervars[-1]
replacement = {most_inner_var: most_inner_var * self.simd_len}
new_index = sympy_subs(expanded_index, replacement)
return new_index
def load(self, name: str, index: sympy.Expr):
var = self.args.input(name)
index = self.rename_indexing(index)
expanded_index = sympy.expand(index)
new_index = self.transform_index(index)
if expanded_index == new_index:
line = f"at::vec::Vectorized<float>({var}[{cexpr(index)}])"
else:
line = f"at::vec::Vectorized<float>::loadu({var} + {cexpr(new_index)})"
return self.cse.generate(self.loads, line)
def store(self, name, index, value, mode=None):
assert "buf" in name
var = self.args.output(name)
index = self.rename_indexing(index)
assert mode is None
expanded_index = sympy.expand(index)
new_index = self.transform_index(index)
assert new_index != expanded_index
line = f"{value}.store({var} + {cexpr(new_index)});"
self.stores.writeline(name, line)
class CppVecKernelChecker(CppVecKernel):
def __init__(self, args, num_threads):
super(CppVecKernelChecker, self).__init__(args, num_threads)
# Since this kernel is only used for checking and does not generate any
# code, we need to decrease the kernel count.
metrics.generated_kernel_count -= 1
metrics.generated_cpp_vec_kernel_count -= 1
# Used to record the graph wrapper code, as the wrapper_code status could be
# changed during the graph run.
self._orig_wrapper_code = None
self.simd_vec = True
self.fast_vec_list = []
for k, v in CppVecOverrides.__dict__.items():
if isinstance(v, staticmethod):
self.fast_vec_list.append(k)
self.exit_stack = contextlib.ExitStack()
def is_legal_data_access(self, var: sympy.Symbol, index: sympy.Expr):
return self.is_var_irrevelant(var, index) or self.is_single_step_var(var, index)
def could_vec(self, name: str, index: sympy.Expr):
if V.graph.get_dtype(name) is not torch.float:
return False
assert self.itervars is not None
# Not a loop
if len(self.itervars) == 0:
return False
most_inner_var = self.itervars[-1]
return self.is_legal_data_access(most_inner_var, index)
def load(self, name: str, index: sympy.Expr):
index = self.rename_indexing(index)
self.simd_vec = self.simd_vec and self.could_vec(name, index)
return self.simd_vec
def store(self, name, index, value, mode=None):
assert "buf" in name
index = self.rename_indexing(index)
if mode:
self.simd_vec = False
return False
self.simd_vec = self.simd_vec and self.could_vec(name, index)
return self.simd_vec
def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
self.simd_vec = False
return self.simd_vec
def __exit__(self, exc_type, exc_val, exc_tb):
assert self._orig_wrapper_code is not None
# Restore the wrapper_code
V.graph.wrapper_code = self._orig_wrapper_code
self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
def __enter__(self):
# Record the graph wrapper code. The wrapper_code status could be
# changed during the graph run. This checker also needs to
# run the graph, but we don't expect it to change any status that would
# impact the code generation. Hence, we record the graph wrapper code,
# replace it with a dummy wrapper_code, and then restore the
# original one once the checker is finished.
self._orig_wrapper_code = V.graph.wrapper_code
V.graph.wrapper_code = WrapperCodeGen()
class VecCheckerProxy:
@staticmethod
def __getattr__(name):
def inner(*args, **kwargs):
if not (name in self.fast_vec_list):
self.simd_vec = False
return self.simd_vec
return inner
@staticmethod
def load(name: str, index: sympy.Expr):
return self.load(name, index)
@staticmethod
def store(name, index, value, mode=None):
return self.store(name, index, value, mode=mode)
@staticmethod
def reduction(name, dtype, src_dtype, reduction_type, index, value):
return self.reduction(
name, dtype, src_dtype, reduction_type, index, value
)
@staticmethod
def constant(val, dtype):
supported_dtype = (torch.float32, torch.int32)
is_supported_dtype = dtype in (supported_dtype)
if not is_supported_dtype:
self.simd_vec = False
return is_supported_dtype
@staticmethod
def index_expr(expr, dtype):
self.simd_vec = False
return self.cse.newvar()
@staticmethod
def indirect_indexing(index_var):
return sympy.Symbol(str(index_var))
@staticmethod
def masked(mask, body, other):
return V.kernel.cse.newvar()
self.exit_stack.enter_context(V.set_ops_handler(VecCheckerProxy()))
self.exit_stack.enter_context(V.set_kernel_handler(self))
return self
class CppKernelProxy(CppKernel):
def __init__(self, args=None, num_threads=None):
super(CppKernelProxy, self).__init__(args, num_threads)
self.simd_vec_kernel = None
self.simd_omp_kernel = None
def vectorize_most_inner_loop(self, loop_nest):
loop_nest.split_most_inner_loop(config.cpp.simdlen)
loop_with_tail = loop_nest.loops[-1]
assert isinstance(loop_with_tail, LoopLevelWithTail)
self.simd_vec_kernel.simd = False
self.simd_vec_kernel.fast_vec = True
loop_with_tail.tail_loop.simd_omp = True
# We chop the loop into two parts based on config.cpp.simdlen - the main loop and the tail loop.
# The main loop can straightforwardly be vectorized with
# config.cpp.simdlen. But the tail loop could still be vectorized as well. For example,
# if config.cpp.simdlen is 8 (256 bits), then the tail loop could still be vectorized
# with 4 lanes (128 bits).
loop_with_tail.tail_loop.simd_len = int(config.cpp.simdlen / 2)
loop_with_tail.tail_loop.simd_vec = False
loop_with_tail.main_loop_body = self.simd_vec_kernel
loop_with_tail.tail_loop_body = self.simd_omp_kernel
return loop_nest
def codegen_loops(self, code, worksharing):
threads = parallel_num_threads()
if self.simd_vec_kernel is None:
assert self.simd_omp_kernel
return self.simd_omp_kernel.codegen_loops(code, worksharing)
assert self.simd_vec_kernel.itervars == self.simd_omp_kernel.itervars
assert self.simd_vec_kernel.ranges == self.simd_omp_kernel.ranges
itervars = self.simd_vec_kernel.itervars
rangs = self.simd_vec_kernel.ranges
loops = [LoopLevel(var, size) for var, size in zip(itervars, rangs)]
# TODO: Support reductions
loops_nest_non_reduc, _ = LoopNest(loops[: self.reduction_depth]), LoopNest(
loops[self.reduction_depth :]
)
assert config.cpp.simdlen
loops_nest_non_reduc.loops[-1].simd_omp = True
par_depth = 0
if loops_nest_non_reduc:
par_depth = self.simd_vec_kernel.decide_parallel_depth(
self.simd_vec_kernel.call_ranges[: self.reduction_depth], threads
)
with contextlib.ExitStack() as stack:
if par_depth:
worksharing.parallel(threads)
loops_nest_non_reduc.mark_parallel(par_depth)
elif threads > 1:
if worksharing.single():
stack.enter_context(code.indent())
self.vectorize_most_inner_loop(loops_nest_non_reduc)
for loop in loops_nest_non_reduc.loops[0:-1]:
code.writelines(loop.lines())
stack.enter_context(code.indent())
loop_with_tail: LoopLevelWithTail = loops_nest_non_reduc.loops[-1]
for loop, kernel in (
(loop_with_tail.main_loop, loop_with_tail.main_loop_body),
(loop_with_tail.tail_loop, loop_with_tail.tail_loop_body),
):
code.writelines(loop.lines())
with contextlib.ExitStack() as stack:
stack.enter_context(code.indent())
code.splice(kernel.loads)
code.splice(kernel.compute)
code.splice(kernel.stores)
class CppScheduling:
def __init__(self, scheduler):
self.scheduler = scheduler
@ -532,38 +927,113 @@ class CppScheduling:
def can_fuse_vertical(cls, node1, node2):
return cls.can_fuse_horizontal(node1, node2) and not node1.is_reduction()
def codegen_nodes(self, nodes):
"""
Turn a set of pre-fused nodes into a C++ kernel.
"""
kernel_group = self.kernel_group
scheduler = self.scheduler
def can_vec(self, nodes):
# TODO: Query cpu arch and vec length from aten
if not codecache.supported_vector_isa():
return False
_, (group, reduction_group) = max(
nodes, key=lambda x: int(x.is_reduction())
).group
in_suffix = False
with kernel_group.new_kernel() as kernel:
vars, reduction_vars = kernel.set_ranges(group, reduction_group)
with CppVecKernelChecker(
deepcopy(self.kernel_group.args), parallel_num_threads()
) as kernel_checker:
vars, reduction_vars = kernel_checker.set_ranges(group, reduction_group)
for node in nodes:
if node.group[1] in [
(group, reduction_group),
(group + reduction_group, ()),
]:
assert not in_suffix
node.run(vars, reduction_vars)
else:
in_suffix = True
assert node.group[1] == (
group,
(),
), f"unexpected group: {node.group[1]} != {group}, {reduction_group}"
# we can fuse in some extra pointwise into the suffix
with kernel.write_to_suffix():
node.run(vars, ())
node.run(vars, ())
kernel_group.finalize_kernel(kernel, scheduler)
return kernel_checker.simd_vec
def _codegen_nodes_impl(self, nodes, is_simd_vec=False):
"""
Turn a set of pre-fused nodes into a C++ kernel.
"""
kernel_group = self.kernel_group
_, (group, reduction_group) = max(
nodes, key=lambda x: int(x.is_reduction())
).group
def create_kernel(_is_simd_vec):
in_suffix = False
with kernel_group.new_kernel(_is_simd_vec) as kernel:
vars, reduction_vars = kernel.set_ranges(group, reduction_group)
for node in nodes:
if node.group[1] in [
(group, reduction_group),
(group + reduction_group, ()),
]:
assert not in_suffix
node.run(vars, reduction_vars)
else:
in_suffix = True
assert node.group[1] == (
group,
(),
), f"unexpected group: {node.group[1]} != {group}, {reduction_group}"
# we can fuse in some extra pointwise into the suffix
with kernel.write_to_suffix():
node.run(vars, ())
return kernel
org_inplace_buffers_flag = config.inplace_buffers
if is_simd_vec:
# Create vectorization kernel
cpp_vec_kernel = create_kernel(True)
# A kernel is divided into two parts - vectorized and non-vectorized - and
# the two parts share the same global contexts like V.graph.wrapper_code and
# V.kernel.args. The vectorized kernel generation has already updated these global
# contexts. Hence, the non-vectorized kernel should not do this again, to avoid
# context conflicts. For now, we only control config.inplace_buffers. In the future,
# we could maintain more contexts.
config.inplace_buffers = False
# Create non-vectorization kernel
cpp_kernel = create_kernel(False)
# Restore the inplace_buffers flag
config.inplace_buffers = org_inplace_buffers_flag
return (cpp_vec_kernel, cpp_kernel)
else:
return (None, create_kernel(False))
def codegen_nodes(self, nodes):
"""
Turn a set of pre-fused nodes into a C++ kernel.
"""
kernel_group = self.kernel_group
can_be_simd_vec = self.can_vec(nodes)
simd_vec_kernel, simd_omp_kernel = self._codegen_nodes_impl(
nodes, can_be_simd_vec
)
assert simd_omp_kernel
metrics.generated_kernel_count -= 1
# Maintain the metrics kernel count
if simd_vec_kernel:
metrics.generated_kernel_count -= 1
cpp_kernel_proxy = CppKernelProxy(
kernel_group.args, kernel_group.ws.num_threads
)
cpp_kernel_proxy.simd_vec_kernel = simd_vec_kernel
cpp_kernel_proxy.simd_omp_kernel = simd_omp_kernel
kernel_group.finalize_kernel(cpp_kernel_proxy, None)
def flush(self):
self.kernel_group.codegen_define_and_call(V.graph.wrapper_code)
@ -580,8 +1050,11 @@ class KernelGroup:
self.stack.enter_context(self.ws)
self.count = 0
def new_kernel(self):
return CppKernel(self.args, self.ws.num_threads)
def new_kernel(self, simd_vec=False):
if simd_vec:
return CppVecKernel(self.args, parallel_num_threads())
else:
return CppKernel(self.args, parallel_num_threads())
def finalize_kernel(self, new_kernel, scheduler):
self.count += 1
@ -660,10 +1133,14 @@ class WorkSharing:
@dataclasses.dataclass
class LoopLevel:
var: sympy.Expr
size: sympy.Expr
var: sympy.Expr = None
size: sympy.Expr = None
offset: sympy.Expr = sympy.Integer(0)
steps: sympy.Expr = sympy.Integer(1)
parallel: int = 0
simd: bool = False
simd_omp: bool = False
simd_len: int = config.cpp.simdlen
simd_vec: bool = False
collapsed: bool = False
reduction_vars: Dict[str, str] = None
@ -675,26 +1152,40 @@ class LoopLevel:
)
else:
reduction = ""
simd = f"simd simdlen({config.cpp.simdlen})"
simd = f"simd simdlen({self.simd_len})" if self.simd_omp else ""
if self.parallel:
# TODO(jansel): look into chunk size and other schedules
line1 = f"#pragma omp for{reduction} "
if self.parallel > 1:
line1 += f" collapse({self.parallel})"
if self.simd:
if self.simd_omp:
line1 = line1.replace(" for ", f" for {simd}")
elif self.simd:
elif self.simd_vec:
line1 = ""
elif self.simd_omp:
line1 = f"#pragma omp {simd}{reduction}"
elif not self.reduction_vars and codecache.is_gcc():
line1 = "#pragma GCC ivdep"
else:
line1 = ""
line2 = f"for({INDEX_TYPE} {self.var}=0; {self.var}<{cexpr(self.size)}; ++{self.var})"
line2 = f"for({INDEX_TYPE} {self.var}={cexpr(self.offset)}; {self.var}<{cexpr(self.size)}; {self.var}+={cexpr(self.steps)})"
if self.collapsed or not line1:
return [line2]
return [line1, line2]
class LoopLevelWithTail(LoopLevel):
def __init__(self, main_loop: LoopLevel, tail_loop: LoopLevel):
super().__init__()
self.main_loop = main_loop
self.tail_loop = tail_loop
self.main_loop_body = None
self.tail_loop_body = None
def lines(self):
raise AssertionError("Not Implemented")
@dataclasses.dataclass
class LoopNest:
loops: List[LoopLevel]
@ -711,7 +1202,35 @@ class LoopNest:
loops[0].parallel = par_depth
for i in range(1, par_depth):
loops[i].collapsed = True
loops[0].simd = loops[par_depth - 1].simd
def split_most_inner_loop(self, factor):
sympy_factor = sympy.Integer(factor)
most_inner_loop = self.loops[-1]
# If the innermost loop needs to be collapsed, we need to
# exclude it since we need to split it into two loops. Meanwhile,
# we still mark it as parallelized.
if most_inner_loop.collapsed:
assert self.loops[0].parallel == len(self.loops)
self.loops[0].parallel -= 1
main_loop_range = ir.IndexingDiv(most_inner_loop.size, sympy_factor)
main_loop = LoopLevel(most_inner_loop.var, main_loop_range)
main_loop.parallel = 1 if most_inner_loop.parallel > 0 else 0
main_loop.collapsed = False
offset = main_loop_range * sympy_factor
tail_loop = LoopLevel(most_inner_loop.var, most_inner_loop.size)
tail_loop.offset = offset
tail_loop.parallel = 1 if most_inner_loop.parallel > 0 else 0
tail_loop.collapsed = False
loop_with_tail = LoopLevelWithTail(main_loop, tail_loop)
loop_with_tail.parallel = 0
loop_with_tail.collapsed = False
self.loops[-1] = loop_with_tail
def codegen(self, code, stack):
for loop in self.loops:


@ -6,8 +6,11 @@
#include <omp.h>
#include <ATen/core/PhiloxRNGEngine.h>
#include <c10/util/Half.h>
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2)
#include <ATen/cpu/vec/vec.h>
#endif
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
typedef at::Half half;
typedef at::BFloat16 bfloat16;


@ -1,8 +1,12 @@
# counter for tracking how many kernels have been generated
generated_kernel_count = 0
generated_cpp_vec_kernel_count = 0
# reset all counters
def reset():
global generated_kernel_count
global generated_cpp_vec_kernel_count
generated_kernel_count = 0
generated_cpp_vec_kernel_count = 0