[inductor] Rewrite Triton templates + epilogue fusion (retry) (#91575)

This reverts commit 94262efc7d381ace82aa74ed2f5f5ec76f8fca95 to reland #91105 / #90738.

Fixes https://github.com/pytorch/torchdynamo/issues/2015

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91575
Approved by: https://github.com/ngimel
Jason Ansel
2023-01-11 00:08:03 +00:00
committed by PyTorch MergeBot
parent 6912f7c564
commit 7c1c239db1
34 changed files with 1584 additions and 1956 deletions


@ -139,7 +139,7 @@ def main():
if args.verbose:
torch._inductor.config.debug = True
torch._inductor.config.triton.autotune = True
torch._inductor.config.triton.autotune_pointwise = True
rows = []
for model in (MicroBenchmarks.sum,):


@ -118,6 +118,10 @@ REQUIRE_EVEN_HIGHER_TOLERANCE = {
"tacotron2",
}
REQUIRE_HIGHER_FP16_TOLERANCE = {
"drq",
}
REQUIRE_COSINE_TOLERACE = {
# Just keeping it here even though it's empty, in case we need this in the future.
}
@ -335,6 +339,8 @@ class TorchBenchmarkRunner(BenchmarkRunner):
cosine = self.args.cosine
# Increase the tolerance for torch allclose
if self.args.float16 or self.args.amp:
if name in REQUIRE_HIGHER_FP16_TOLERANCE:
return 1e-2, cosine
return 1e-3, cosine
if is_training and current_device == "cuda":
tolerance = 1e-3


@ -1156,7 +1156,6 @@ def main():
'include/THH/generic/*.h',
'include/sleef.h',
"_inductor/codegen/*.h",
"_inductor/codegen/*.j2",
'share/cmake/ATen/*.cmake',
'share/cmake/Caffe2/*.cmake',
'share/cmake/Caffe2/public/*.cmake',


@ -0,0 +1,146 @@
# Owner(s): ["module: inductor"]
import functools
import logging
from unittest.mock import patch
import torch
import torch._dynamo.config as dynamo_config
import torch._inductor.config as inductor_config
import torch._inductor.select_algorithm as select_algorithm
import torch.nn.functional as F
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import counters
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA
torch.backends.cuda.matmul.allow_tf32 = False
def patches(fn):
def skip_cache(self, key, generate):
return generate()
for patcher in [
patch.object(dynamo_config, "log_level", logging.INFO),
patch.object(dynamo_config, "verbose", True),
patch.object(inductor_config, "debug", True),
patch.object(inductor_config, "max_autotune", True),
patch.object(inductor_config, "epilogue_fusion", True),
patch.object(select_algorithm, "VERIFY", dict(atol=1e-4, rtol=1e-4)),
patch.object(select_algorithm.AlgorithmSelectorCache, "lookup", skip_cache),
]:
fn = patcher(fn)
@functools.wraps(fn)
def wrapped(*args, **kwargs):
counters.clear()
torch.manual_seed(12345)
assert (
not torch.backends.cuda.matmul.allow_tf32
), "correctness testing is allergic to tf32"
return fn(*args, **kwargs)
return wrapped
class TestSelectAlgorithm(TestCase):
@patches
def test_linear_relu(self):
@torch.compile
def foo(input, weight, bias):
return F.relu(F.linear(input, weight, bias))
foo(
torch.randn(64, 32, device="cuda"),
torch.randn(16, 32, device="cuda"),
torch.randn(16, device="cuda"),
)
# Autotuning checks correctness of each version
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
# It would be nice to assert this got fused into a single kernel, but that
# only happens if we select a triton template (and not aten).
@patches
def test_addmm(self):
@torch.compile
def foo(input, weight, bias):
return torch.addmm(bias, input, weight)
foo(
torch.randn(20, 33, device="cuda"),
torch.randn(33, 16, device="cuda"),
torch.randn(20, 16, device="cuda"),
)
# Autotuning checks correctness of each version
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
@patches
def test_mm(self):
@torch.compile
def foo(a, b):
return torch.mm(a, b)
foo(
torch.randn(8, 32, device="cuda"),
torch.randn(32, 8, device="cuda"),
)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
@patches
def test_mm_skip(self):
@torch.compile
def foo(a, b):
return torch.mm(a, b)
foo(
torch.randn(8, 32, device="cuda", dtype=torch.float64),
torch.randn(32, 8, device="cuda", dtype=torch.float64),
)
# float64 not supported by tl.dot()
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0)
@patches
def test_bmm(self):
@torch.compile
def foo(a, b):
return torch.bmm(a, b)
foo(
torch.randn(2, 8, 32, device="cuda"),
torch.randn(2, 32, 8, device="cuda"),
)
# Autotuning checks correctness of each version
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
@patches
def test_mm_not_even_k(self):
@torch.compile
def foo(a, b):
return torch.mm(a, b)
foo(
torch.randn(11, 22, device="cuda"),
torch.randn(22, 33, device="cuda"),
)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
@patches
def test_baddbmm(self):
@torch.compile
def foo(a, b, c):
return torch.baddbmm(c, a, b)
foo(
torch.randn(2, 8, 32, device="cuda"),
torch.randn(2, 32, 8, device="cuda"),
torch.randn(2, 1, 8, device="cuda"),
)
# Autotuning checks correctness of each version
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
if __name__ == "__main__":
from torch._inductor.utils import is_big_gpu
if IS_LINUX and HAS_CUDA and is_big_gpu(0):
run_tests()

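The assertions in the tests above rely on torch._dynamo.utils.counters being bumped when select_algorithm benchmarks candidate kernels. A minimal standalone sketch of that pattern, assuming a CUDA device and without the cache-skipping patches applied by the decorator above:

import torch
import torch.nn.functional as F
import torch._inductor.config as inductor_config
from torch._dynamo.utils import counters

inductor_config.max_autotune = True  # exercise the new select_algorithm path

@torch.compile
def f(x, w, b):
    return F.relu(F.linear(x, w, b))

counters.clear()
f(
    torch.randn(64, 32, device="cuda"),
    torch.randn(16, 32, device="cuda"),
    torch.randn(16, device="cuda"),
)
# the tests above expect 1 here for the single linear op
print(counters["inductor"]["select_algorithm_autotune"])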

@ -79,7 +79,7 @@ requires_multigpu = functools.partial(
unittest.skipIf, not HAS_MULTIGPU, "requires multiple cuda devices"
)
torch._inductor.config.triton.autotune = False # too slow
torch._inductor.config.triton.autotune_pointwise = False # too slow
# For OneDNN bf16 path, OneDNN requires the cpu has intel avx512 with avx512bw,
@ -2505,76 +2505,6 @@ class CommonTemplate:
self.assertEqual(a.stride(), c.stride())
self.assertEqual(c.stride()[2], 1)
@requires_cuda()
@patch.object(config.triton, "convolution", "triton")
@patch.object(config.triton, "dense_indexing", "True")
def test_triton_conv(self):
@torch._dynamo.optimize("inductor", nopython=True)
def triton_conv(
x,
w,
bias,
stride,
padding,
dilation,
groups,
):
y = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
return y
stride, padding, dilation, groups = (1, 1), (0, 0), (1, 1), 1
dtype = torch.float32
x = torch.randn((32, 128, 32, 32), dtype=dtype, device=self.device)
w = torch.randn((32, 128, 1, 1), dtype=dtype, device=self.device)
bias = torch.randn((32), dtype=dtype, device=self.device)
y = triton_conv(x, w, bias, stride, padding, dilation, groups)
y_correct = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
self.assertTrue(same(y, y_correct, cos_similarity=True, tol=0.1))
@requires_cuda()
@patch.object(config.triton, "convolution", "autotune")
@patch.object(config.triton, "dense_indexing", "True")
def test_conv_autotune(self):
@torch._dynamo.optimize("inductor", nopython=True)
def triton_conv(
x,
w,
bias,
stride,
padding,
dilation,
groups,
):
y = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
return y
stride, padding, dilation, groups = (1, 1), (0, 0), (1, 1), 1
dtype = torch.float32
x = torch.randn((32, 128, 32, 32), dtype=dtype, device=self.device)
w = torch.randn((32, 128, 1, 1), dtype=dtype, device=self.device)
bias = torch.randn((32), dtype=dtype, device=self.device)
y = triton_conv(x, w, bias, stride, padding, dilation, groups)
y_correct = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
self.assertTrue(same(y, y_correct, cos_similarity=True, tol=0.1))
@patch.object(config.triton, "mm", "triton")
def test_triton_mm2(self):
@torch._dynamo.optimize("inductor", nopython=True)
def fn(x, y):
return torch.relu(torch.mm(x, y))
N = 1024
a = torch.randn([N, N], device=self.device, dtype=torch.float32)
b = torch.randn([N, N], device=self.device, dtype=torch.float32)
c1 = torch.relu(torch.mm(a, b))
torch._inductor.metrics.reset()
c = fn(a, b)
assert torch.allclose(c1, c, atol=1e-3, rtol=1e-3)
if self.device == "cuda":
assert torch._inductor.metrics.generated_kernel_count == 1
def test_std(self):
def fn(x):
return (
@ -4560,12 +4490,6 @@ class CommonTemplate:
)
expected_kernel = 0
# codegen mm kernel from template
if config.triton.mm != "aten" and self.device == "cuda":
expected_kernel = 1
if config.triton.mm == "autotune":
self.assertLessEqual(
torch._inductor.metrics.generated_kernel_count, expected_kernel
)
self.assertEqual(
torch._inductor.metrics.generated_kernel_count, expected_kernel
)
@ -4641,15 +4565,6 @@ class CommonTemplate:
result.sum().backward()
expected_kernel = 4
if config.triton.mm != "aten" and self.device == "cuda":
# fwd: 2 * (mm+dropout) kernels = 2 kernels
# bwd: dropout + (mm) + 2 * (mm+dropout) kernels = 4 kernels
# expect 2 + 4 = 6 kernels
expected_kernel = 6
if config.triton.mm == "autotune":
self.assertLessEqual(
torch._inductor.metrics.generated_kernel_count, expected_kernel
)
self.assertEqual(
torch._inductor.metrics.generated_kernel_count, expected_kernel
)
@ -4979,7 +4894,6 @@ class CommonTemplate:
inputs = (inputs[1], inputs[0])
self.assertTrue(same(opt(*inputs), fn(*inputs)))
@patch.object(config.triton, "mm", "aten")
def test_list_clearing(self):
if self.device == "cpu":
@ -5685,7 +5599,7 @@ if HAS_CUDA:
res = opt_mod(*args)
self.assertTrue(same(ref, res))
@patch.object(config.triton, "autotune", True)
@patch.object(config.triton, "autotune_pointwise", True)
def test_inplace_add_alpha_autotune(self):
def fn(x, y):
aten.add_.Tensor(x, y, alpha=0.55)
@ -5703,7 +5617,7 @@ if HAS_CUDA:
fn_compiled([x3, y])
assert same(x2, x3)
@patch.object(config.triton, "autotune", True)
@patch.object(config.triton, "autotune_pointwise", True)
def test_inplace_buffer_autotune(self):
def foo(x, y, z):
a = x @ y


@ -242,8 +242,12 @@ def requires_static_shapes(fn):
return _fn
def rand_strided(size, stride, dtype=torch.float32, device="cpu"):
needed_size = sum((shape - 1) * stride for shape, stride in zip(size, stride)) + 1
def rand_strided(size, stride, dtype=torch.float32, device="cpu", extra_size=0):
needed_size = (
sum((shape - 1) * stride for shape, stride in zip(size, stride))
+ 1
+ extra_size
)
if dtype.is_floating_point:
buffer = torch.randn(needed_size, dtype=dtype, device=device)
else:

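As a worked example of the arithmetic above (hypothetical sizes, not taken from the diff): the flat backing buffer only has to reach the last addressable element of the strided view, and extra_size pads it beyond that.

# a (2, 3) row-major view touches elements 0..5, so 6 are needed;
# extra_size=1 pads the flat buffer to 7 elements
size, stride, extra_size = (2, 3), (3, 1), 1
needed_size = sum((s - 1) * st for s, st in zip(size, stride)) + 1 + extra_size
assert needed_size == 7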

@ -3,6 +3,7 @@ import dataclasses
import functools
import getpass
import hashlib
import json
import logging
import multiprocessing
import os
@ -60,6 +61,36 @@ def cache_dir():
)
class DiskCache:
@staticmethod
@functools.lru_cache(None)
def _subdir():
subdir = os.path.join(cache_dir(), "cached_tunings")
os.makedirs(subdir, exist_ok=True)
return subdir
@staticmethod
@functools.lru_cache(4096)
def _read_file(path):
with open(path, "r") as fd:
return json.loads(fd.read())
def __init__(self, unique_name):
super().__init__()
self.unique_name = unique_name
def lookup(self, key: Any, generate: Callable[[], Any]):
"""
Check if we have already generated this key; if not, call generate()
to populate the cache.
"""
path = os.path.join(self._subdir(), code_hash(self.unique_name + repr(key)))
if not os.path.exists(path):
value = generate()
write_atomic(path, json.dumps(value))
return self._read_file(path)
def get_lock_dir():
lock_dir = os.path.join(cache_dir(), "locks")
if not os.path.exists(lock_dir):
@ -88,14 +119,18 @@ def write(source_code, ext, extra=""):
if not os.path.exists(subdir):
os.makedirs(subdir, exist_ok=True)
if not os.path.exists(path):
# use a temp file for thread safety
fd, tmp_path = tempfile.mkstemp(dir=subdir)
with os.fdopen(fd, "w") as f:
f.write(source_code)
os.rename(tmp_path, path)
write_atomic(path, source_code)
return basename, path
def write_atomic(path: str, source_code: str):
# use a temp file for thread safety
fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(path))
with os.fdopen(fd, "w") as f:
f.write(source_code)
os.rename(tmp_path, path)
def cpp_compiler():
if isinstance(config.cpp.cxx, (list, tuple)):
search = tuple(config.cpp.cxx)

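The DiskCache class added above memoizes autotuning results on disk, keyed by a hash of the tuning key, and the new write_atomic helper (temp file plus rename) keeps concurrent compile workers from reading half-written entries. A self-contained sketch of the same pattern, using hypothetical names rather than inductor's API:

import hashlib
import json
import os
import tempfile

def atomic_write(path, text):
    # write to a temp file in the same directory, then rename into place
    fd, tmp = tempfile.mkstemp(dir=os.path.dirname(path))
    with os.fdopen(fd, "w") as f:
        f.write(text)
    os.rename(tmp, path)

def disk_cached(cache_dir, key, generate):
    # compute the value once, then reuse the JSON file on later calls
    os.makedirs(cache_dir, exist_ok=True)
    path = os.path.join(cache_dir, hashlib.sha256(repr(key).encode()).hexdigest())
    if not os.path.exists(path):
        atomic_write(path, json.dumps(generate()))
    with open(path) as f:
        return json.load(f)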

@ -3,7 +3,6 @@ import builtins
import torch
from .. import config, triton_ops
from ..triton_ops.autotune import mm_autotune, mm_heuristics
from ..utils import dynamo_testing
from ..virtualized import V
@ -141,79 +140,6 @@ def tuned_conv(
return best_kernel
def tuned_mm(
a_shape,
b_shape,
a_stride,
b_stride,
device,
dtype,
adjust_triton=0.95,
):
"""
Return the best kernel name given mm input size;
Considering potential pointwise fusion of mm, we could adjust triton timing
by multiplying adjust_triton (default=0.95)
"""
sizevars = V.graph.sizevars
a_shape = [sizevars.size_hint(s) for s in a_shape]
b_shape = [sizevars.size_hint(s) for s in b_shape]
a_stride = [sizevars.size_hint(s) for s in a_stride]
b_stride = [sizevars.size_hint(s) for s in b_stride]
a = rand_strided(a_shape, a_stride, device=device, dtype=dtype)
b = rand_strided(b_shape, b_stride, device=device, dtype=dtype)
c = torch.empty((a_shape[0], b_shape[1]), device=device, dtype=dtype)
id_args = [
*a_shape,
*b_shape,
]
use_cuda = a.is_cuda
# gen_key
key = tuple(id_args)
key = ("mm",) + key
# candidate kernels
kernels = ["aten.mm.out"]
if use_cuda:
kernels += ["triton_ops.matmul_out"]
# if only one choice, return that kernel
if len(kernels) == 1:
kernel = kernels[0]
return kernel
timings = {}
if key not in autotune.cache:
# bench_start = time.time()
for kernel in kernels:
runnable_kernel = str2func(kernel)
if "triton_ops" in kernel:
run_args = (a, b, c)
run_kwargs = {}
inner_kernel = str2func(
kernel.replace("matmul_out", "_matmul_out") + ".kernel"
)
inner_kernel.kernel_decorators = []
# fix SPLIT_K = 1 for fusable kernels
mm_heuristics()(mm_autotune(get_io_bound_configs=False)(inner_kernel))
else:
run_args = (a, b)
run_kwargs = {"out": c}
timing, _, _ = autotune._bench(runnable_kernel, *run_args, **run_kwargs)
if "triton_ops" in kernel:
timing = timing * adjust_triton
timings[kernel] = timing
# bench_end = time.time()
# bench_time = bench_end - bench_start
autotune.cache[key] = builtins.min(timings, key=timings.get)
if config.debug:
print("for key = ", key)
print("timing", timings)
print("best_kernel", autotune.cache[key])
best_kernel = autotune.cache[key]
return best_kernel
def tuned_conv_layout(
kernel,
x_shape,


@ -1,4 +1,3 @@
import collections
import contextlib
import itertools
import logging
@ -200,13 +199,29 @@ class KernelArgs:
return odict[name]
def __init__(self, sizevars=None):
self.input_buffers = collections.OrderedDict()
self.output_buffers = collections.OrderedDict()
self.inplace_buffers = collections.OrderedDict()
self.sizevars = sizevars or collections.OrderedDict()
self.input_buffers = dict()
self.output_buffers = dict()
self.inplace_buffers = dict()
self.sizevars = sizevars or dict()
def __repr__(self):
return "KernelArgs({})".format(
", ".join(
map(
repr,
[
self.input_buffers,
self.output_buffers,
self.inplace_buffers,
self.sizevars,
],
)
)
)
def input(self, name):
name = V.graph.scheduler.mutation_real_name.get(name, name)
if V.graph.scheduler:
name = V.graph.scheduler.mutation_real_name.get(name, name)
assert name not in V.graph.removed_buffers, name
if name in self.output_buffers:
return self.output_buffers[name]
@ -217,7 +232,8 @@ class KernelArgs:
return self._lookup("in_ptr", self.input_buffers, name)
def output(self, name):
name = V.graph.scheduler.mutation_real_name.get(name, name)
if V.graph.scheduler:
name = V.graph.scheduler.mutation_real_name.get(name, name)
assert name not in V.graph.removed_buffers, name
if name in self.inplace_buffers:
return self.inplace_buffers[name].inner_name
@ -428,7 +444,10 @@ class CSE:
var = self.newvar()
self.cache[expr] = var
if write:
V.kernel.current_node.codegen_originating_info(buffer, only_once=True)
if V.kernel.current_node:
V.kernel.current_node.codegen_originating_info(
buffer, only_once=True
)
buffer.writeline(f"{self.prefix}{var} = {expr}{self.suffix}")
return self.cache[expr]
@ -552,8 +571,9 @@ class Kernel(CodeGen):
self.store_buffer_names.add(name)
if mode is None:
self.cse.store_cache[name] = value
for other_name in self.current_node.get_mutations():
self.cse.store_cache[other_name] = value
if self.current_node:
for other_name in self.current_node.get_mutations():
self.cse.store_cache[other_name] = value
if name not in V.graph.removed_buffers:
return self.store(name, index, value, mode=mode)
@ -571,7 +591,8 @@ class Kernel(CodeGen):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
V.graph.scheduler.remove_kernel_local_buffers()
if V.graph.scheduler:
V.graph.scheduler.remove_kernel_local_buffers()
super().__exit__(exc_type, exc_val, exc_tb)
def rename_indexing(self, index) -> sympy.Expr:


@ -416,14 +416,17 @@ class IterationRangesRoot(IterationRanges):
self.nodes[expr] = node
return self.nodes[expr]
def construct(self, lengths: List[sympy.Expr]):
def construct_entries(self, lengths: List[sympy.Expr]):
divisor = sympy.Integer(1)
itervars = []
for length in reversed(lengths):
itervars.append(self.lookup(divisor, length).symbol())
itervars.append(self.lookup(divisor, length))
divisor = divisor * length
return list(reversed(itervars))
def construct(self, lengths: List[sympy.Expr]):
return [e.symbol() for e in self.construct_entries(lengths)]
def vars_and_sizes(self, index: sympy.Expr):
"""Figure out vars from this tree used in index"""
nodes = [V.kernel.range_tree_nodes.get(s) for s in index.free_symbols]
@ -497,6 +500,11 @@ class IterationRangesEntry(IterationRanges):
self.codegen = functools.lru_cache(None)(self._codegen)
self.expr = expr
def set_name(self, name):
self.codegen = lambda: name
self.codegen.cache_clear = lambda: None
self.name = name
def cache_clear(self):
self.codegen.cache_clear()
@ -732,6 +740,7 @@ class TritonKernel(Kernel):
*,
copy_shape=None,
dense_indexing=False,
override_mask=None,
):
"""
Compute the index and mask to pass to tl.load() or tl.store()
@ -742,7 +751,9 @@ class TritonKernel(Kernel):
mask_vars: Set[str] = set()
for var in index_vars:
if var.name.startswith("tmp"):
if override_mask:
pass
elif var.name.startswith("tmp"):
# indirect indexing
cse_var = self.cse.varname_map[var.name]
mask_vars.update(cse_var.mask_vars)
@ -770,7 +781,10 @@ class TritonKernel(Kernel):
dense_mask_vars.add(f"{tree.prefix}mask")
if (need_dense and not have_dense) or isinstance(index, sympy.Integer):
index_str = f"{index_str} + tl.zeros({self.dense_size_str()}, tl.int32)"
if copy_shape:
index_str = f"{index_str} + tl.zeros({copy_shape}.shape, tl.int32)"
else:
index_str = f"{index_str} + tl.zeros({self.dense_size_str()}, tl.int32)"
if isinstance(index, sympy.Integer):
return index_str, set(), "None"
else:
@ -779,6 +793,9 @@ class TritonKernel(Kernel):
mask_vars = dense_mask_vars
index_str = f"{index_str} + tl.zeros({copy_shape}.shape, tl.int32)"
if override_mask:
mask_vars = {override_mask}
if self._load_mask:
mask_vars.add(self._load_mask)
@ -829,6 +846,7 @@ class TritonKernel(Kernel):
ep = ", eviction_policy='evict_last'"
else:
ep = ""
# "other" below is a workaround for https://github.com/openai/triton/issues/737
# for bool, even though it's likely subject to the same bug, setting `other` leads
# to LLVM errors so we are skipping it for now
@ -1106,6 +1124,13 @@ class TritonKernel(Kernel):
wrapper.writeline("''')")
return wrapper.getvalue()
def codegen_template_wrapper(self, src_code):
wrapper = IndentedBuffer()
wrapper.writeline("async_compile.triton('''")
wrapper.splice(src_code, strip=True)
wrapper.writeline("''')")
return wrapper.getvalue()
def codegen_static_numels(self, code):
"""
We get a small speedup from hard coding numels if they are static.
@ -1187,6 +1212,9 @@ class TritonScheduling:
if not (numel1 == numel2 and rnumel1 == rnumel2):
return False
if node1.is_template():
return True # skip checks for compatible tiling
# check for a bad combined tiling
tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1)
tiling2 = self.select_tiling(node2.get_nodes(), numel1, rnumel1)
@ -1212,7 +1240,10 @@ class TritonScheduling:
for n in node1.get_nodes()
):
return False
if config.triton.tiling_prevents_reduction_fusion:
if (
config.triton.tiling_prevents_reduction_fusion
and not node1.is_template()
):
return self.select_tiling(node1.get_nodes(), numel1) in (
(numel1, 1),
(numel2, rnumel2, 1),
@ -1348,8 +1379,13 @@ class TritonScheduling:
index_vars = kernel.split_and_set_ranges(node.get_ranges())
node.codegen(index_vars)
wrapper = V.graph.wrapper_code
src_code = kernel.codegen_kernel()
kernel_name = self.define_kernel(src_code, node_schedule)
kernel.call_kernel(V.graph.wrapper_code, kernel_name)
self.scheduler.free_buffers()
def define_kernel(self, src_code, node_schedule):
wrapper = V.graph.wrapper_code
if src_code in wrapper.kernels:
kernel_name = wrapper.kernels[src_code]
else:
@ -1366,7 +1402,25 @@ class TritonScheduling:
# not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
src_code = src_code.replace("#pragma CMT", "#")
wrapper.define_kernel(kernel_name, src_code)
kernel.call_kernel(wrapper, kernel_name)
return kernel_name
def codegen_template(self, template_node, epilogue_nodes):
"""
Codegen a triton template
"""
_, (numel, rnumel) = template_node.group
assert rnumel == 1
kernel, render = template_node.node.make_kernel_render(template_node.node)
with kernel:
for node in [template_node, *epilogue_nodes]:
node.mark_run()
render() # warmup run to get the args right
for node in epilogue_nodes:
node.codegen(kernel.split_and_set_ranges(node.get_ranges()))
src_code = kernel.codegen_template_wrapper(render())
kernel_name = self.define_kernel(src_code, [template_node, *epilogue_nodes])
kernel.call_kernel(V.graph.wrapper_code, kernel_name)
self.scheduler.free_buffers()
def codegen_sync(self):

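codegen_template above renders the template twice: a warm-up render populates the kernel's argument list, the epilogue nodes then emit their code into the kernel, and a final render produces the fused source. A toy, self-contained sketch of that two-pass idea (not inductor code; all names here are made up):

class ToyTemplateKernel:
    def __init__(self):
        self.args = ["A", "B", "OUT"]  # in inductor this grows during codegen
        self.epilogue = []             # lines contributed by fused pointwise nodes

    def render(self):
        lines = [f"def kernel({', '.join(self.args)}):", "    acc = A @ B"]
        lines += [f"    {line}" for line in self.epilogue]
        lines.append("    OUT[:] = acc")
        return "\n".join(lines)

kernel = ToyTemplateKernel()
kernel.render()                              # warm-up render fixes the arg list
kernel.epilogue.append("acc = acc.relu()")   # an epilogue node adds its code
print(kernel.render())                       # final render includes the fused suffix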

@ -1,181 +0,0 @@
@conv_heuristics()
@triton.jit
def {{kernel_name}}(
{% for i in template_inout_argdefs %}
{{i}},
{% endfor %}
# stride of tensor
stride_xn,
stride_xc,
stride_xh,
stride_xw,
stride_wn,
stride_wc,
stride_wh,
stride_ww,
stride_yn,
stride_yc,
stride_yh,
stride_yw,
stride_biasn,
# Tensor dimensions
BATCH,
IN_C,
IN_H,
IN_W,
KERNEL_N,
KERNEL_H,
KERNEL_W,
OUT_H,
OUT_W,
# parameters of conv
stride_h,
stride_w,
padding_h,
padding_w,
dilation_h,
dilation_w,
output_padding_h,
output_padding_w,
groups: tl.constexpr,
# pointer inc for x
delta_x_ptr,
# fusable kernels args
{% for i in extra_argdefs %}
{{i}},
{% endfor %}
# Metaparameters
ACC_TYPE: tl.constexpr,
CONV1X1_NHWC: tl.constexpr,
# blocks in different dimension
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
# reduction tiling parameter for matmul
BLOCK_K: tl.constexpr,
):
"""
each program instance computes a [BLOCK_BATCH, BLOCK_N, BLOCK_H, BLOCK_W] block of y
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of y it should compute.
pid_nhw = tl.program_id(0)
pid_k = tl.program_id(1)
# offset for output y
off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
off_y_n = off_y_nhw // (OUT_H * OUT_W)
off_y_hw = off_y_nhw % (OUT_H * OUT_W)
off_y_h = off_y_hw // OUT_W
off_y_w = off_y_hw % OUT_W
# offset for the initial ptr for x
off_x_n = off_y_n
off_x_h = off_y_h * stride_h - padding_h
off_x_w = off_y_w * stride_w - padding_w
off_x_nhw = off_x_n * stride_xn + off_x_h * stride_xh + off_x_w * stride_xw
off_x_crs = tl.arange(0, BLOCK_K)
CRS = IN_C * KERNEL_H * KERNEL_W
# load inc ptr of x, update x_ptrs
if not CONV1X1_NHWC:
delta_x_ptrs = delta_x_ptr + off_x_crs
off_x_crs_unpacked = tl.load(delta_x_ptrs, mask=off_x_crs < CRS, other=0)
x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
else:
x_ptrs = x + off_x_nhw[:, None] + off_x_crs[None, :]
mask_x = (
(off_x_n < BATCH)
& (off_x_h >= 0)
& (off_x_h < IN_H)
& (off_x_w >= 0)
& (off_x_w < IN_W)
)[:, None] & (off_x_crs < CRS)[None, :]
# offset for the initial ptr for w
off_w_crs = tl.arange(0, BLOCK_K)
off_w_k = off_y_k
w_ptrs = w + off_w_crs[:, None] + off_w_k[None, :] * stride_wn
mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]
# ------ load x ------
matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
# ------ load w ------
matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
# -----------------------------------------------------------
# allocate accumulator
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for crs in range(0, CRS, BLOCK_K):
# ------ matrix multiplication ------
acc += tl.dot(matrix_x, matrix_w)
# ------ update ptrs ------
w_ptrs += BLOCK_K
# load inc ptr of x, update x_ptrs
if not CONV1X1_NHWC:
delta_x_ptrs += BLOCK_K
off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K)
off_x_crs_unpacked = tl.load(delta_x_ptrs, mask=off_x_crs < CRS, other=0)
x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
else:
off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K)
x_ptrs += BLOCK_K
mask_x = (
(off_x_n < BATCH)
& (off_x_h >= 0)
& (off_x_h < IN_H)
& (off_x_w >= 0)
& (off_x_w < IN_W)
)[:, None] & (off_x_crs < CRS)[None, :]
mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]
# ------ prefetch ------
# ------ load x ------
matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
# ------ load w ------
matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
acc = acc.to({{out_def}}.dtype.element_ty)
{% if keep_store %}
# rematerialize -- this saves some registers
# offset for output y
off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
off_y_n = off_y_nhw // (OUT_H * OUT_W)
off_y_hw = off_y_nhw % (OUT_H * OUT_W)
# consider output padding
off_y_h = off_y_hw // OUT_W + output_padding_h
off_y_w = off_y_hw % OUT_W + output_padding_w
# y ptrs in the block of [BLOCK_M, BLOCK_N]
y_ptrs = (
{{out_def}}
+ off_y_n[:, None] * stride_yn
+ off_y_h[:, None] * stride_yh
+ off_y_w[:, None] * stride_yw
+ off_y_k[None, :] * stride_yc
)
# out-of-bounds check
mask_y = (
(off_y_n < BATCH)[:, None]
& (off_y_h < OUT_H + output_padding_h)[:, None]
& (off_y_w < OUT_W + output_padding_w)[:, None]
& (off_y_k < KERNEL_N)[None, :]
)
tl.store(y_ptrs, acc, mask=mask_y)
{% endif %}
{% if pointwise_code %}
{{ pointwise_code | indent(4, true) }}
{#
z = tl.load(z_ptrs, mask=mask_z)
acc += z
#}
{% endif %}
return


@ -1,200 +0,0 @@
@conv_heuristics()
@triton.jit
def {{kernel_name}}(
{% for i in template_inout_argdefs %}
{{i}},
{% endfor %}
# stride of tensor
stride_xn,
stride_xc,
stride_xh,
stride_xw,
stride_wn,
stride_wc,
stride_wh,
stride_ww,
stride_yn,
stride_yc,
stride_yh,
stride_yw,
stride_biasn,
# Tensor dimensions
BATCH,
IN_C,
IN_H,
IN_W,
KERNEL_N,
KERNEL_H,
KERNEL_W,
OUT_H,
OUT_W,
# parameters of conv
stride_h,
stride_w,
padding_h,
padding_w,
dilation_h,
dilation_w,
output_padding_h,
output_padding_w,
groups,
# pointer inc for x
delta_xh_ptr,
delta_xw_ptr,
delta_xc_ptr,
# fusable kernels args
{% for i in extra_argdefs %}
{{i}},
{% endfor %}
# Metaparameters
ACC_TYPE: tl.constexpr,
CONV1X1_NHWC: tl.constexpr,
# blocks in different dimension
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
# reduction tiling parameter for matmul
BLOCK_K: tl.constexpr,
):
"""
each program instance computes a [BLOCK_BATCH, BLOCK_N, BLOCK_H, BLOCK_W] block of y
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of y it should compute.
pid_nhw = tl.program_id(0)
pid_k = tl.program_id(1)
# offset for output y
off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
off_y_n = off_y_nhw // (OUT_H * OUT_W)
off_y_hw = off_y_nhw % (OUT_H * OUT_W)
off_y_h = off_y_hw // OUT_W + output_padding_h
off_y_w = off_y_hw % OUT_W + output_padding_w
# offset for the initial ptr for x
off_x_n = off_y_n
off_x_h = off_y_h * stride_h - padding_h
off_x_w = off_y_w * stride_w - padding_w
off_x_nhw = off_x_n * stride_xn + off_x_h * stride_xh + off_x_w * stride_xw
off_x_crs = tl.arange(0, BLOCK_K)
CRS = IN_C * KERNEL_H * KERNEL_W
# load inc ptr of x, update x_ptrs
if not CONV1X1_NHWC:
delta_xh_ptrs = delta_xh_ptr + off_x_crs
delta_xw_ptrs = delta_xw_ptr + off_x_crs
delta_xc_ptrs = delta_xc_ptr + off_x_crs
delta_xh = tl.load(delta_xh_ptrs, mask=off_x_crs < CRS, other=0)
delta_xw = tl.load(delta_xw_ptrs, mask=off_x_crs < CRS, other=0)
delta_xc = tl.load(delta_xc_ptrs, mask=off_x_crs < CRS, other=0)
off_x_crs_unpacked = (
delta_xh * stride_xh + delta_xw * stride_xw + delta_xc * stride_xc
)
x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
else:
x_ptrs = x + off_x_nhw[:, None] + off_x_crs[None, :]
delta_xh = 0
delta_xw = 0
mask_x = (
(off_x_n < BATCH)[:, None]
& (off_x_crs < CRS)[None, :]
& (off_x_h[:, None] + delta_xh[None, :] >= 0)
& (off_x_h[:, None] + delta_xh[None, :] < IN_H)
& (off_x_w[:, None] + delta_xw[None, :] >= 0)
& (off_x_w[:, None] + delta_xw[None, :] < IN_W)
)
# offset for the initial ptr for w
off_w_crs = tl.arange(0, BLOCK_K)
off_w_k = off_y_k
w_ptrs = w + off_w_crs[:, None] + off_w_k[None, :] * stride_wn
mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]
# ------ load x ------
matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
# ------ load w ------
matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
# -----------------------------------------------------------
# allocate accumulator
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for crs in range(0, CRS, BLOCK_K):
# ------ matrix multiplication ------
acc += tl.dot(matrix_x, matrix_w)
# ------ update ptrs ------
w_ptrs += BLOCK_K
# load inc ptr of x, update x_ptrs
off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K)
if not CONV1X1_NHWC:
delta_xh_ptrs += BLOCK_K
delta_xw_ptrs += BLOCK_K
delta_xc_ptrs += BLOCK_K
delta_xh = tl.load(delta_xh_ptrs, mask=off_x_crs < CRS, other=0)
delta_xw = tl.load(delta_xw_ptrs, mask=off_x_crs < CRS, other=0)
delta_xc = tl.load(delta_xc_ptrs, mask=off_x_crs < CRS, other=0)
off_x_crs_unpacked = (
delta_xh * stride_xh + delta_xw * stride_xw + delta_xc * stride_xc
)
x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
else:
x_ptrs += BLOCK_K
mask_x = (
(off_x_n < BATCH)[:, None]
& (off_x_crs < CRS)[None, :]
& (off_x_h[:, None] + delta_xh[None, :] >= 0)
& (off_x_h[:, None] + delta_xh[None, :] < IN_H)
& (off_x_w[:, None] + delta_xw[None, :] >= 0)
& (off_x_w[:, None] + delta_xw[None, :] < IN_W)
)
mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]
# ------ prefetch ------
# ------ load x ------
matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
# ------ load w ------
matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
acc = acc.to({{out_def}}.dtype.element_ty)
{% if keep_store %}
# rematerialize -- this saves some registers
# offset for output y
off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
off_y_n = off_y_nhw // (OUT_H * OUT_W)
off_y_hw = off_y_nhw % (OUT_H * OUT_W)
# consider output padding
off_y_h = off_y_hw // OUT_W + output_padding_h
off_y_w = off_y_hw % OUT_W + output_padding_w
# y ptrs in the block of [BLOCK_M, BLOCK_N]
y_ptrs = (
{{out_def}}
+ off_y_n[:, None] * stride_yn
+ off_y_h[:, None] * stride_yh
+ off_y_w[:, None] * stride_yw
+ off_y_k[None, :] * stride_yc
)
# out-of-bounds check
mask_y = (
(off_y_n < BATCH)[:, None]
& (off_y_h < OUT_H + output_padding_h)[:, None]
& (off_y_w < OUT_W + output_padding_w)[:, None]
& (off_y_k < KERNEL_N)[None, :]
)
tl.store(y_ptrs, acc, mask=mask_y)
{% endif %}
{% if pointwise_code %}
{{ pointwise_code | indent(4, true) }}
{#
z = tl.load(z_ptrs, mask=mask_z)
acc += z
#}
{% endif %}
return


@ -1,80 +0,0 @@
import torch
import triton
import triton.language as tl
{# from triton.ops.matmul import get_configs_io_bound #}
@mm_autotune()
@mm_heuristics()
@triton.jit
def {{kernel_name}}(
{% for i in template_inout_argdefs %}
{{i}},
{% endfor %}
M,
N,
K,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
# fusable kernels args
{% for i in extra_argdefs %}
{{i}},
{% endfor %}
allow_tf32: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr,
SPLIT_K: tl.constexpr,
EVEN_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
):
# matrix multiplication
pid = tl.program_id(0)
pid_z = tl.program_id(1)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
# pointers
A_ptrs = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B_ptrs = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for k in range(K, 0, -BLOCK_K * SPLIT_K):
if EVEN_K:
a = tl.load(A_ptrs)
b = tl.load(B_ptrs)
else:
a = tl.load(A_ptrs, mask=rk[None, :] < k, other=0.0)
b = tl.load(B_ptrs, mask=rk[:, None] < k, other=0.0)
acc += tl.dot(a, b, allow_tf32=allow_tf32)
A_ptrs += BLOCK_K * SPLIT_K * stride_ak
B_ptrs += BLOCK_K * SPLIT_K * stride_bk
acc = acc.to({{out_def}}.dtype.element_ty)
{% if keep_store %}
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
C_ptrs = {{out_def}} + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
# handles write-back with reduction-splitting
tl.store(C_ptrs, acc, mask=mask)
{% endif %}
{% if pointwise_code %}
{{ pointwise_code | indent(4, true) }}
{% endif %}


@ -1,351 +0,0 @@
import logging
import os
import sympy
from .. import config, ir
from ..virtualized import V
from .common import IndentedBuffer
from .triton import TritonKernel
log = logging.getLogger((__name__))
template_dict = {ir.Convolution: "triton_conv", ir.MatrixMultiply: "triton_mm"}
class TritonTemplateKernel(TritonKernel):
def __init__(self, node: ir.ExternKernel, *groups):
from jinja2 import Environment, FileSystemLoader, StrictUndefined
self.node = node
self.template_name = template_dict[type(node)]
env = Environment(
loader=FileSystemLoader(os.path.dirname(__file__)),
trim_blocks=True,
lstrip_blocks=True,
undefined=StrictUndefined,
)
pid_cache = {}
if isinstance(node, ir.Convolution):
pid_cache = {
"tl.program_id(0)": "pid_nhw",
"tl.program_id(1)": "pid_k",
}
self.map_args()
KERNEL_H = self.args_dict["KERNEL_H"]
KERNEL_W = self.args_dict["KERNEL_W"]
padding_h = self.args_dict["padding_h"]
padding_w = self.args_dict["padding_w"]
if ((KERNEL_H == "1" and KERNEL_W == "1")) or (
(padding_h == "0") and (padding_w == "0")
):
self.template_name += "_delta_x"
else:
self.template_name += "_delta_x_hwc"
elif isinstance(node, ir.MatrixMultiply):
pid_cache = {
"tl.program_id(0)": "pid_m",
"tl.program_id(1)": "pid_n",
}
self.template = env.get_template(self.template_name + ".j2")
super(TritonTemplateKernel, self).__init__(*groups, pid_cache=pid_cache)
def rename_vars(self):
for k, v in self.inout_dict.items():
self.args.output_buffers[v] = k
if isinstance(self.node, ir.Convolution):
self.cse.store_cache[self.inout_dict["y"]] = "acc"
elif isinstance(self.node, ir.MatrixMultiply):
self.cse.store_cache[self.inout_dict["C"]] = "acc"
def assign_block_numel(self):
code = IndentedBuffer()
if isinstance(self.node, ir.Convolution):
code.writeline("XBLOCK: tl.constexpr = BLOCK_M")
code.writeline("YBLOCK: tl.constexpr = BLOCK_N")
code.writeline(
"xnumel = BATCH * (OUT_H + 2 * output_padding_h) * (OUT_W + 2 * output_padding_w)"
)
code.writeline("ynumel = KERNEL_N")
elif isinstance(self.node, ir.MatrixMultiply):
code.writeline("XBLOCK: tl.constexpr = BLOCK_M")
code.writeline("YBLOCK: tl.constexpr = BLOCK_N")
code.writeline("xnumel = M")
code.writeline("ynumel = N")
return code
def indexing(self, index: sympy.Expr, copy_shape=None, dense_indexing=True):
# use dense_indexing for TritonTemplateKernel to avoid map::at error
return super().indexing(
index, copy_shape=copy_shape, dense_indexing=dense_indexing
)
def codegen_body(
self, name, fuse, could_remove_kernel_buf, kernel_buf_replace_name
):
"""
put render_variables into the template
to generate the final code
"""
# get extra_argdefs from fusable triton kernels
self.extra_argdefs = []
self.extra_call_args = []
argdefs, call_args, _ = self.args.python_argdefs()
# add extra args if it is different from
# current TritonTemplateKernel args
for (argdef, call_arg) in zip(argdefs, call_args):
if (
argdef not in self.inout_dict.keys()
and argdef not in self.args_dict.keys()
):
self.extra_argdefs.append(argdef)
self.extra_call_args.append(call_arg)
if could_remove_kernel_buf:
if isinstance(self.node, ir.Convolution):
self.inout_dict.pop("y")
elif isinstance(self.node, ir.MatrixMultiply):
self.inout_dict.pop("C")
self.template_inout_argdefs = list(self.inout_dict.keys())
if kernel_buf_replace_name is not None:
idx = self.extra_call_args.index(kernel_buf_replace_name)
kernel_buf_replace_def = self.extra_argdefs[idx]
super().codegen_body()
self.pointwise_code = IndentedBuffer()
self.pointwise_code.splice(self.assign_block_numel())
self.pointwise_code.splice(self.body)
render_dict = {}
render_dict["kernel_name"] = name
render_dict["template_inout_argdefs"] = self.template_inout_argdefs
render_dict["extra_argdefs"] = self.extra_argdefs
render_dict["pointwise_code"] = self.pointwise_code.getvalue() if fuse else None
render_dict["keep_store"] = not could_remove_kernel_buf
render_dict["out_def"] = (
self.out_def() if not could_remove_kernel_buf else kernel_buf_replace_def
)
self.body = self.template.render(render_dict) + "\n"
def out_def(self):
if isinstance(self.node, ir.Convolution):
return "y"
elif isinstance(self.node, ir.MatrixMultiply):
return "C"
def codegen_kernel(
self,
name=None,
fuse=False,
could_remove_kernel_buf=False,
kernel_buf_replace_name=None,
):
code = IndentedBuffer()
self.codegen_body(name, fuse, could_remove_kernel_buf, kernel_buf_replace_name)
code.splice(self.body)
if name is not None:
return code.getvalue()
wrapper = IndentedBuffer()
wrapper.writeline("TritonCodeCache.load('''")
wrapper.splice(code.getvalue(), strip=True)
wrapper.writeline("''').kernel")
return wrapper.getvalue()
def map_args(self):
"""
map the constant args or
kernel[grid](..., IN_C, IN_H, IN_W, strides,...)
"""
(
self.inout_dict,
self.args_dict,
self.const_dict,
self.other_dict,
) = self.node.map_args()
def precompute(self, wrapper, kernel_name):
"""
some triton kernels need a host-precomputed tensor;
for example, triton_conv needs a precomputed delta_x_ptr
"""
if isinstance(self.node, ir.Convolution):
if self.const_dict["CONV1X1_NHWC"] == "False":
IN_C = self.args_dict["IN_C"]
KERNEL_H = self.args_dict["KERNEL_H"]
KERNEL_W = self.args_dict["KERNEL_W"]
dilation_h = self.args_dict["dilation_h"]
dilation_w = self.args_dict["dilation_w"]
stride_wc = self.args_dict["stride_wc"]
stride_wh = self.args_dict["stride_wh"]
stride_ww = self.args_dict["stride_ww"]
stride_xc = self.args_dict["stride_xc"]
stride_xh = self.args_dict["stride_xh"]
stride_xw = self.args_dict["stride_xw"]
device = self.other_dict["device"]
if self.template_name == "triton_conv_delta_x":
assert "delta_x_ptr" not in self.args_dict.keys()
self.args_dict["delta_x_ptr"] = f"delta_x_{kernel_name}"
wrapper.writeline(
f"from {config.inductor_import}.triton_ops import _conv"
)
wrapper.writeline(
f"delta_x_{kernel_name} = _conv._delta_x_ptr("
f"{IN_C}, {KERNEL_H}, {KERNEL_W}, "
f"{dilation_h}, {dilation_w}, "
f"{stride_wc}, {stride_wh}, {stride_ww}, "
f"{stride_xc}, {stride_xh}, {stride_xw}, {device})"
)
# triton_conv_delta_x_hwc
else:
assert "delta_xh_ptr" not in self.args_dict.keys()
assert "delta_xw_ptr" not in self.args_dict.keys()
assert "delta_xc_ptr" not in self.args_dict.keys()
self.args_dict["delta_xh_ptr"] = f"delta_xh_{kernel_name}"
self.args_dict["delta_xw_ptr"] = f"delta_xw_{kernel_name}"
self.args_dict["delta_xc_ptr"] = f"delta_xc_{kernel_name}"
wrapper.writeline(
f"from {config.inductor_import}.triton_ops import _conv"
)
wrapper.writeline(
f"delta_xh_{kernel_name}, delta_xw_{kernel_name}, delta_xc_{kernel_name}"
f" = _conv._delta_x_ptr_hwc("
f"{IN_C}, {KERNEL_H}, {KERNEL_W}, "
f"{dilation_h}, {dilation_w}, "
f"{stride_wc}, {stride_wh}, {stride_ww}, "
f"{stride_xc}, {stride_xh}, {stride_xw}, {device})"
)
# else, delta_x_ptr is None
else:
assert "delta_x_ptr" not in self.args_dict.keys()
self.args_dict["delta_x_ptr"] = "None"
return
def gen_grid(self, name):
code = IndentedBuffer()
if isinstance(self.node, ir.Convolution):
BATCH = self.args_dict["BATCH"]
OUT_H = self.args_dict["OUT_H"]
OUT_W = self.args_dict["OUT_W"]
KERNEL_N = self.args_dict["KERNEL_N"]
code.splice(
f"""
def grid_{name}(META):
return (
triton.cdiv({BATCH} * {OUT_H} * {OUT_W}, META["BLOCK_M"]),
triton.cdiv({KERNEL_N}, META["BLOCK_N"]),
)
"""
)
if isinstance(self.node, ir.MatrixMultiply):
M = self.args_dict["M"]
N = self.args_dict["N"]
code.splice(
f"""
def grid_{name}(META):
return (
triton.cdiv({M}, META["BLOCK_M"]) * triton.cdiv({N}, META["BLOCK_N"]),
META["SPLIT_K"],
)
"""
)
return code.getvalue()
def call_kernel(self, wrapper, name: str):
# gen code to call kernel
# e.g.
# def grid(META):
# return (...)
# kernel1[grid](arg0, arg1, ...)
extra_args = ", ".join(self.extra_call_args)
self_args = ", ".join({**self.inout_dict, **self.args_dict}.values())
self_const_kwargs = ", ".join(f"{k}={v}" for k, v in self.const_dict.items())
args = self_args + (
", " + extra_args if extra_args and len(extra_args) > 0 else ""
)
args_kwargs = args + ", " + self_const_kwargs
lines = self.gen_grid(name).split("\n")
for l in lines:
wrapper.writeline(l)
wrapper.writeline(f"{name}[grid_{name}]({args_kwargs})")
def should_use_template(node: ir.ExternKernel):
template_kernels = [ir.Convolution, ir.MatrixMultiply]
if type(node) in template_kernels and ir.is_triton(node.get_device()):
if isinstance(node, ir.Convolution):
return node.kernel != "aten.convolution"
elif isinstance(node, ir.MatrixMultiply):
return node.kernel != "aten.mm.out"
return False
def template_can_fuse(snode1, snode2):
assert snode1.is_template()
if snode1.group != snode2.group:
return False
tiling = snode1.get_nodes()[0].node.get_template_tiling()
for node in snode2.get_nodes():
if not TritonKernel.is_compatible(tiling, node.get_ranges()):
return False
return True
def template_codegen(scheduler, scheduler_node, epilogue):
"""
codegen function for triton templates
scheduler: Scheduler
scheduler_node: ExternKernelSchedulerNode
"""
log.debug("template_codegen: %s -- %s", scheduler_node, epilogue)
wrapper = V.graph.wrapper_code
_, groups = scheduler_node.group
with TritonTemplateKernel(
scheduler_node.node, *scheduler_node.node.get_template_tiling()
) as kernel:
# map const args/ shape/ strides to kernel args
kernel.map_args()
# set self.args name to match the TritonTemplateKernel's args names
kernel.rename_vars()
# scheduler.pop_group will keep iterating all reachable fusable SchedulerNodes
assert type(kernel.node) in template_dict.keys()
kernel.store_buffer_names.add(scheduler_node.get_name())
for node in epilogue:
node.mark_run()
node.codegen(kernel.split_and_set_ranges(node.get_ranges()))
could_remove_kernel_buf = (
kernel.args.output_buffers[scheduler_node.get_name()] == "REMOVED"
)
kernel_buf_replace_name = None
if could_remove_kernel_buf:
for node in epilogue:
if not kernel.args.is_removed(node.get_name()):
kernel_buf_replace_name = node.get_name()
break
assert kernel_buf_replace_name is not None
kernel_name = "triton_template_" + wrapper.next_kernel_suffix()
# code gen kernel
wrapper.header.splice(
kernel.codegen_kernel(
kernel_name,
bool(epilogue),
could_remove_kernel_buf,
kernel_buf_replace_name,
)
)
# gen precompute tensor (like delta_x_ptr) if needed
kernel.precompute(wrapper, kernel_name)
# code gen call to kernel
kernel.call_kernel(wrapper, kernel_name)


@ -272,6 +272,7 @@ class WrapperCodeGen(CodeGen):
import random
from torch import empty_strided, as_strided, device
from {codecache.__name__} import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
@ -299,19 +300,6 @@ class WrapperCodeGen(CodeGen):
"""
)
if config.triton.mm != "aten":
self.header.splice(
f"""
from {config.inductor_import}.triton_ops.autotune import mm_heuristics
from {config.inductor_import}.triton_ops.autotune import mm_autotune
"""
)
if config.triton.use_bmm:
self.header.writeline(
f"from {config.inductor_import}.triton_ops.batched_matmul import bmm_out as triton_bmm_out"
)
self.write_prefix()
for name, value in V.graph.constants.items():
@ -325,6 +313,21 @@ class WrapperCodeGen(CodeGen):
self.write_get_cuda_stream
)
@functools.lru_cache(None)
def add_import_once(line):
self.header.writeline(line)
self.add_import_once = add_import_once
self._metas = {}
def add_meta_once(self, meta):
meta = repr(meta)
if meta not in self._metas:
var = f"meta{len(self._metas)}"
self._metas[meta] = var
self.header.writeline(f"{var} = {meta}")
return self._metas[meta]
@cache_on_self
def get_output_refs(self):
return [x.codegen_reference() for x in V.graph.graph_outputs]

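add_meta_once above deduplicates kernel meta dicts in the generated wrapper: each distinct repr is written into the header once as a numbered variable, and later launches reuse that name. A minimal standalone sketch of the same idea (illustrative, not the actual generated code):

_metas, header = {}, []

def add_meta_once(meta):
    key = repr(meta)
    if key not in _metas:
        var = f"meta{len(_metas)}"
        _metas[key] = var
        header.append(f"{var} = {key}")
    return _metas[key]

add_meta_once({"GROUP_M": 8, "EVEN_K": True})  # -> "meta0", header gains one line
add_meta_once({"GROUP_M": 8, "EVEN_K": True})  # -> "meta0" again, header unchanged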

@ -36,6 +36,15 @@ inplace_buffers = True
# codegen benchmark harness
benchmark_harness = True
# fuse pointwise into templates
epilogue_fusion = False
# do epilogue fusions before other fusions
epilogue_fusion_first = False
# enable slow autotuning passes to select algorithms
max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"
# control store vs recompute heuristic
# For fanouts, rematerialization can lead to exponential blowup. So, have
# smaller threshold
@ -143,12 +152,9 @@ class triton:
# Synchronize after every kernel launch, to help pinpoint bugs
debug_sync_kernel = False
# choose conv backend, "aten" or "triton" or "autotune"
# choose conv backend, "aten" or "triton"
convolution = "aten"
# choose mm backend, "aten" or "triton" or "autotune"
mm = "aten"
# Always load full blocks (rather than broadcasting inside the block)
# Set default as True because otherwise we will encounter a `map::at` error
# in triton if loading from 1-dim tensor using 2-dim pointer offset
@ -159,10 +165,9 @@ class triton:
# limit tiling dimensions
max_tiles = 2
# use triton.autotune?
autotune = True
use_bmm = False
# use triton.autotune for pointwise ops with complex layouts
# this should only be disabled for debugging/testing
autotune_pointwise = True
# should we stop a fusion to allow better tiling?
tiling_prevents_pointwise_fusion = True
@ -209,33 +214,25 @@ class trace:
class InductorConfigContext:
static_memory: bool
matmul_tune: str
matmul_padding: bool
triton_autotune: bool
triton_bmm: bool
triton_mm: str
max_autotune: bool
triton_convolution: str
rematerialize_threshold: int
rematerialize_acc_threshold: int
def _save(self):
self.static_memory = triton.cudagraphs
self.matmul_tune = triton.mm
self.matmul_padding = shape_padding
self.triton_autotune = triton.autotune
self.triton_bmm = triton.use_bmm
self.triton_mm = triton.mm
self.max_autotune = max_autotune
self.triton_convolution = triton.convolution
self.rematerialize_threshold = realize_reads_threshold
self.rematerialize_acc_threshold = realize_acc_reads_threshold
def _apply(self):
global shape_padding, realize_reads_threshold, realize_acc_reads_threshold, max_autotune
triton.cudagraphs = self.static_memory
triton.mm = self.matmul_tune
shape_padding = self.matmul_padding
triton.autotune = self.triton_autotune
triton.use_bmm = self.triton_bmm
triton.mm = self.triton_mm
max_autotune = self.max_autotune
triton.convolution = self.triton_convolution
realize_reads_threshold = self.rematerialize_threshold
realize_acc_reads_threshold = self.rematerialize_acc_threshold
@ -254,11 +251,7 @@ class InductorConfigContext:
self.static_memory = True
def max_autotune():
self.static_memory = False
self.matmul_padding = True
self.triton_convolution = "autotune"
self.triton_mm = "autotune"
self.matmul_padding = True
self.max_autotune = True
modes = {
x.__name__.replace("_", "-"): x

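For reference, the knobs introduced above can be flipped either via the environment variable read at import time or directly on the config module; a minimal usage sketch (the script name is hypothetical):

# via the environment, before Python starts:
#   TORCHINDUCTOR_MAX_AUTOTUNE=1 python train.py
import torch._inductor.config as inductor_config

inductor_config.max_autotune = True      # enable the slow algorithm-selection pass
inductor_config.epilogue_fusion = True   # allow fusing pointwise ops into templates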

@ -27,12 +27,10 @@ from torch.fx.passes.tools_common import legalize_graph
from . import config, ir # noqa: F811, this is needed
from .scheduler import (
BaseSchedulerNode,
ExternKernelSchedulerNode,
FusedSchedulerNode,
NopKernelSchedulerNode,
OutputNode,
SchedulerNode,
TemplateSchedulerNode,
)
from .utils import dynamo_config, dynamo_debug_utils, dynamo_utils
from .virtualized import V
@ -110,10 +108,10 @@ def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
group: Any = None
# create call_function node for each Buffer and Kernel
for snode in snodes:
if isinstance(snode, ExternKernelSchedulerNode):
if snode.is_extern():
node_type = "extern"
group = node_type
elif isinstance(snode, TemplateSchedulerNode):
elif snode.is_template():
node_type = "template"
group = node_type
elif isinstance(snode, NopKernelSchedulerNode):


@ -165,14 +165,6 @@ def pad_dim(x, padded_length, dim):
@register_decomposition([aten.addmm])
def addmm(input, mat1, mat2, *, beta=1, alpha=1):
if config.triton.mm != "aten":
out = torch.mm(mat1, mat2)
if not isinstance(alpha, numbers.Number) or alpha != 1:
out = out * alpha
if not isinstance(beta, numbers.Number) or beta != 1:
input = input * beta
return input + out
if (
config.shape_padding
and check_device(mat1, mat2)


@ -120,6 +120,7 @@ class GraphLowering(torch.fx.Interpreter):
self.name = "GraphLowering"
self._can_use_cpp_wrapper = config.cpp_wrapper
self.graph_id = graph_id
self.scheduler = None
def get_dtype(self, buffer_name: str):
if buffer_name in self.constants:
@ -175,10 +176,7 @@ class GraphLowering(torch.fx.Interpreter):
def check_buffer_for_cpp_wrapper(self, buffer: ir.ComputedBuffer):
if isinstance(buffer, ir.ExternKernel):
if not isinstance(
buffer,
(ir.MatrixMultiply, ir.BatchMatrixMultiply, ir.MatrixMultiplyAdd),
):
if not getattr(buffer, "cpp_kernel", False):
self.disable_cpp_wrapper("ExternKernel")
def register_buffer(self, buffer: ir.ComputedBuffer):
@ -296,6 +294,7 @@ class GraphLowering(torch.fx.Interpreter):
out = lowerings[target](*args, **kwargs)
return out
except Exception as e:
log.exception("Error from lowering")
raise LoweringException(e, target, args, kwargs) from e
def get_attr(self, target, args, kwargs):
@ -471,6 +470,7 @@ class GraphLowering(torch.fx.Interpreter):
self.init_wrapper_code()
self.scheduler = Scheduler(self.buffers)
assert self.scheduler is not None # mypy can't figure this out
self.scheduler.codegen()
assert self.wrapper_code is not None
return self.wrapper_code.generate()


@ -1578,6 +1578,9 @@ class Constant(BaseConstant):
return loader
def realize(self):
pass
@dataclasses.dataclass
class IndexingConstant(BaseConstant):
@ -1698,6 +1701,24 @@ class Layout(IRNode):
class FixedLayout(Layout):
"""A Tensor layout we cannot change"""
def __init__(
self,
device: torch.device,
dtype: torch.dtype,
size: List[Expr],
stride: List[Expr] = None,
offset: Expr = Integer(0),
):
if stride is None:
stride = FlexibleLayout.contiguous_strides(size)
super().__init__(
device,
dtype,
size,
stride,
offset,
)
def make_indexer(self):
"""A closure containing math to read a given element"""
@ -2250,6 +2271,58 @@ class ComputedBuffer(Buffer):
return self.data.constant_to_device(device)
class TemplateBuffer(Buffer):
"""
Represents a Triton (in the future, other types of) template operator
that we can fuse an epilogue onto.
"""
def __init__(self, layout, inputs, make_kernel_render):
super().__init__(name=None, layout=layout)
self.inputs = InputsKernel.unwrap_storage(inputs)
self.make_kernel_render = make_kernel_render
self.name = V.graph.register_buffer(self)
def get_read_writes(self):
return self.normalized_read_writes()
@cache_on_self
def normalized_read_writes(self):
name = self.get_name()
indexer = self.layout.make_indexer()
def dummy(index, rindex):
assert len(rindex) == 0
return ops.store(name, indexer(index), "fake")
deps = dependencies.extract_read_writes(
dummy, self.get_size(), (), normalize=True
)
deps.reads = {dependencies.StarDep(x.get_name()) for x in self.inputs}
return deps
def get_reduction_size(self):
return 1
def get_reduction_type(self):
return None
def is_no_op(self):
return False
def should_allocate(self):
return True
def simplify_and_reorder(self):
return (
(
self.get_size(),
(),
),
None,
)
@dataclasses.dataclass
class InputsKernel(Buffer):
inputs: List[Buffer]
@ -2688,12 +2761,24 @@ class ExternKernelOut(ExternKernel):
self.cpp_kernel,
)
def __init__(self, layout, inputs, constant_args=(), kwargs=None, output_view=None):
def __init__(
self,
layout,
inputs,
constant_args=(),
kwargs=None,
output_view=None,
kernel=None,
cpp_kernel=None,
):
super().__init__(
None, layout, self.unwrap_storage(inputs), constant_args, kwargs or {}
)
self.output_view = output_view
self.name = V.graph.register_buffer(self)
if kernel is not None:
self.kernel = kernel
self.cpp_kernel = cpp_kernel
def should_allocate(self):
return True
@ -2784,198 +2869,6 @@ class IndexPutFallback(ExternKernel):
self.name = V.graph.register_buffer(self)
class MatrixMultiply(ExternKernelOut):
kernel = "aten.mm.out"
cpp_kernel = "at::mm_out"
def __init__(
self, layout, inputs, constant_args=(), output_view=None, kernel="aten.mm.out"
):
super().__init__(layout, inputs, constant_args, output_view)
self.kernel = kernel
@classmethod
def create(cls, a, b):
*m, k1 = a.get_size()
k2, n = b.get_size()
V.graph.sizevars.guard_equals(k1, k2)
a = cls.realize_input(a)
b = cls.realize_input(b)
if len(m) != 1 and not a.get_layout().is_contiguous():
a = cls.copy_input(a)
else:
a = cls.require_stride1(a)
b = cls.require_stride1(b)
# choose runtime kernel
config_mm = config.triton.mm
# default kernel is aten
kernel = "aten.mm.out"
if config_mm == "aten":
kernel = "aten.mm.out"
elif config_mm == "triton" and a.get_device().type == "cuda":
kernel = "triton_ops.matmul_out"
elif config_mm == "autotune":
from .codegen.autotuner import tuned_mm
kernel = tuned_mm(
a.get_size(),
b.get_size(),
a.get_stride(),
b.get_stride(),
a.get_device(),
a.get_dtype(),
)
return MatrixMultiply(
layout=FlexibleLayout(
device=a.get_device(),
dtype=a.get_dtype(),
size=list(m) + [n],
),
inputs=[a, b],
kernel=kernel,
)
def get_template_tiling(self):
tile1, tile2 = self.get_size()
return (
tile1,
tile2,
sympy.Integer(1),
)
def map_args(self):
# a, b
in_args = [x.codegen_reference() for x in self.inputs]
# const_args = self.constant_args
inout_dict = OrderedDict(
[
("A", f"{in_args[0]}"),
("B", f"{in_args[1]}"),
("C", f"{self.get_name()}"),
]
)
# batch==1 bmm->mm
if len(self.get_stride()) == 3:
assert self.get_size()[0] == 1
stride_cm = self.get_stride()[1]
stride_cn = self.get_stride()[2]
else:
stride_cm = self.get_stride()[0]
stride_cn = self.get_stride()[1]
args_dict = OrderedDict(
[
("M", f"{self.inputs[0].get_size()[0]}"),
("N", f"{self.inputs[1].get_size()[1]}"),
("K", f"{self.inputs[0].get_size()[1]}"),
("stride_am", f"{self.inputs[0].get_stride()[0]}"),
("stride_ak", f"{self.inputs[0].get_stride()[1]}"),
("stride_bk", f"{self.inputs[1].get_stride()[0]}"),
("stride_bn", f"{self.inputs[1].get_stride()[1]}"),
("stride_cm", f"{stride_cm}"),
("stride_cn", f"{stride_cn}"),
]
)
# accumulator types
ACC_TYPE = (
"tl.float32"
if self.inputs[0].get_dtype()
in [torch.float16, torch.bfloat16, torch.float32]
else "tl.int32"
)
# dict for tl.constexpr
const_dict = OrderedDict(
[
("GROUP_M", "8"),
("ACC_TYPE", ACC_TYPE),
("allow_tf32", f"{torch.backends.cuda.matmul.allow_tf32}"),
]
)
other_dict = OrderedDict()
return inout_dict, args_dict, const_dict, other_dict
class MatrixMultiplyAdd(ExternKernelOut):
def __init__(self, layout, inputs, constant_args=(), kwargs=None, output_view=None):
super().__init__(layout, inputs, constant_args, kwargs or {}, output_view)
self.kernel = "aten.addmm.out"
self.cpp_kernel = "at::addmm_out"
self.ordered_kwargs_for_cpp_kernel = ["beta", "alpha"]
@classmethod
def create(cls, inp, a, b, beta, alpha):
m, k1 = a.get_size()
k2, n = b.get_size()
V.graph.sizevars.guard_equals(k1, k2)
inp = cls.realize_input(inp)
a = cls.realize_input(a)
b = cls.realize_input(b)
a = cls.require_stride1(a)
b = cls.require_stride1(b)
return MatrixMultiplyAdd(
layout=FlexibleLayout(
device=a.get_device(),
dtype=a.get_dtype(),
size=[m] + [n],
),
inputs=[inp, a, b],
kwargs={"beta": beta, "alpha": alpha},
)
class BatchMatrixMultiply(ExternKernelOut):
kernel = "aten.bmm.out"
cpp_kernel = "at::bmm_out"
def __init__(self, layout, inputs, constant_args=(), output_view=None):
super().__init__(layout, inputs, constant_args, output_view)
if (
config.triton.use_bmm
and len(inputs) > 0
and inputs[0].get_device().type == "cuda"
):
self.kernel = "triton_bmm_out"
@classmethod
def create(cls, a, b):
b1, m, k1 = a.get_size()
b2, k2, n = b.get_size()
b3 = V.graph.sizevars.guard_equals(b1, b2)
V.graph.sizevars.guard_equals(k1, k2)
a = cls.require_stride1(cls.realize_input(a))
b = cls.require_stride1(cls.realize_input(b))
output_layout = FlexibleLayout(
device=a.get_device(),
dtype=a.get_dtype(),
size=[b3, m, n],
).as_fixed()
if b3 == 1:
# convert to normal mm
data = MatrixMultiply(
layout=output_layout.as_fixed(),
inputs=[SqueezeView.create(a, dim=0), SqueezeView.create(b, dim=0)],
)
data.output_view = ReinterpretView(
data,
FlexibleLayout(
device=a.get_device(),
dtype=a.get_dtype(),
size=[m, n],
).as_fixed(),
)
else:
data = BatchMatrixMultiply(
layout=output_layout,
inputs=[a, b],
)
return data
class DeviceCopy(ExternKernelOut):
@classmethod
def create(cls, x, device):
@ -3862,7 +3755,14 @@ class StorageBox(MutableBox):
def realize(self):
if isinstance(
self.data, (ComputedBuffer, InputsKernel, InputBuffer, ReinterpretView)
self.data,
(
ComputedBuffer,
InputsKernel,
InputBuffer,
ReinterpretView,
TemplateBuffer,
),
):
return self.data.get_name()
assert isinstance(self.data, (Pointwise, Reduction)), type(self.data)

View File

@ -0,0 +1,126 @@
import torch
from ..lowering import register_lowering
from ..select_algorithm import (
autotune_select_algorithm,
ExternKernelChoice,
TritonTemplate,
)
from ..utils import ceildiv as cdiv, use_triton_template
from .mm_common import addmm_epilogue, mm_args, mm_configs, mm_options
aten = torch.ops.aten
def bmm_grid(b, m, n, meta):
return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1)
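For intuition about the launch geometry, a small self-contained sketch with hypothetical sizes (batch=4, m=512, n=1024 and a config with BLOCK_M=64, BLOCK_N=128); bmm_grid computes the same tuple via cdiv.

def ceildiv(a: int, b: int) -> int:
    # same ceiling division as torch._inductor.utils.ceildiv
    return -(-a // b)

# Axis 0 enumerates the 8 * 8 = 64 output tiles; axis 1 carries the batch index,
# which the kernel reads back as tl.program_id(1).
assert (ceildiv(512, 64) * ceildiv(1024, 128), 4, 1) == (64, 4, 1)
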
bmm_template = TritonTemplate(
name="bmm",
grid=bmm_grid,
source=r"""
{{def_kernel("A", "B")}}
M = {{size("A", -2)}}
N = {{size("B", -1)}}
K = {{size("A", -1)}}
stride_aq = {{stride("A", 0)}}
stride_am = {{stride("A", 1)}}
stride_ak = {{stride("A", 2)}}
stride_bq = {{stride("B", 0)}}
stride_bk = {{stride("B", 1)}}
stride_bn = {{stride("B", 2)}}
# based on triton.ops.matmul
pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = tl.arange(0, BLOCK_K)
idx_q = tl.program_id(1) # batch dimension for BMM
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for k in range(K, 0, -BLOCK_K):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
a = tl.load(A, mask=rk[None, :] < k, other=0.)
b = tl.load(B, mask=rk[:, None] < k, other=0.)
acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
idx_q = tl.program_id(1) # batch dimension for BMM
idx_m = rm[:, None]
idx_n = rn[None, :]
mask = (idx_m < M) & (idx_n < N)
# inductor generates a suffix
{{store_output(("idx_q", "idx_m", "idx_n"), "acc", "mask")}}
""",
)
aten_bmm = ExternKernelChoice(torch.bmm, "at::bmm_out")
aten_baddbmm = ExternKernelChoice(torch.baddbmm, "at::baddbmm_out")
@register_lowering(aten.bmm)
def tuned_bmm(mat1, mat2, *, layout=None):
m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
# options to tune from
choices = [aten_bmm.bind((mat1, mat2), layout)]
if use_triton_template(layout):
for config in mm_configs():
choices.append(
bmm_template.generate(
(mat1, mat2),
layout,
**mm_options(config, k, layout),
)
)
return autotune_select_algorithm(choices, [mat1, mat2], layout)
# Don't register this since it is slower than decomposing it
# @register_lowering(aten.baddbmm)
def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout)
# options to tune from
choices = [aten_baddbmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)]
if use_triton_template(layout):
for config in mm_configs():
choices.append(
bmm_template.generate(
(inp, mat1, mat2),
layout,
**mm_options(config, k, layout),
prefix_args=1,
epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
)
)
return autotune_select_algorithm(choices, [inp, mat1, mat2], layout)
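A hedged end-to-end usage sketch (the shapes and the compiled function are illustrative, not from the patch): with max_autotune enabled, compiling a CUDA torch.bmm call routes through tuned_bmm above, which benchmarks the aten extern kernel against the Triton template configs and keeps the fastest.

import torch
import torch._inductor.config as inductor_config

inductor_config.max_autotune = True  # gate checked by use_triton_template

@torch.compile
def batched_mm(a, b):
    return torch.bmm(a, b)

if torch.cuda.is_available():
    a = torch.randn(4, 64, 32, device="cuda", dtype=torch.float16)
    b = torch.randn(4, 32, 128, device="cuda", dtype=torch.float16)
    out = batched_mm(a, b)  # triggers autotune_select_algorithm for aten.bmm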

View File

@ -0,0 +1,121 @@
import logging
import torch
from ..lowering import register_lowering
from ..select_algorithm import (
autotune_select_algorithm,
ExternKernelChoice,
TritonTemplate,
)
from ..utils import use_triton_template
from .mm_common import addmm_epilogue, mm_args, mm_configs, mm_grid, mm_options
log = logging.getLogger(__name__)
aten = torch.ops.aten
mm_template = TritonTemplate(
name="mm",
grid=mm_grid,
source=r"""
{{def_kernel("A", "B")}}
M = {{size("A", 0)}}
N = {{size("B", 1)}}
K = {{size("A", 1)}}
stride_am = {{stride("A", 0)}}
stride_ak = {{stride("A", 1)}}
stride_bk = {{stride("B", 0)}}
stride_bn = {{stride("B", 1)}}
# based on triton.ops.matmul
pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = tl.arange(0, BLOCK_K)
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for k in range(K, 0, -BLOCK_K):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
a = tl.load(A, mask=rk[None, :] < k, other=0.)
b = tl.load(B, mask=rk[:, None] < k, other=0.)
acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
idx_m = rm[:, None]
idx_n = rn[None, :]
mask = (idx_m < M) & (idx_n < N)
# inductor generates a suffix
{{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)
aten_mm = ExternKernelChoice(torch.mm, "at::mm_out")
aten_addmm = ExternKernelChoice(torch.addmm, "at::addmm_out")
@register_lowering(aten.mm)
def tuned_mm(mat1, mat2, *, layout=None):
m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
# options to tune from
choices = [aten_mm.bind((mat1, mat2), layout)]
if use_triton_template(layout):
for config in mm_configs():
choices.append(
mm_template.generate(
(mat1, mat2),
layout,
**mm_options(config, k, layout),
)
)
return autotune_select_algorithm(choices, [mat1, mat2], layout)
@register_lowering(aten.addmm)
def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
# don't expand inp to make sure fused addmm from cublasLt is used
if not use_triton_template(layout):
choices = [aten_addmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)]
return autotune_select_algorithm(choices, [inp, mat1, mat2], layout)
# TODO this is not quite fair benchmarking because we won't use fused cublasLt addmm
# options to tune from
choices = [
aten_addmm.bind((inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta)
]
if use_triton_template(layout):
for config in mm_configs():
choices.append(
mm_template.generate(
(inp_expanded, mat1, mat2),
layout,
**mm_options(config, k, layout),
prefix_args=1,
epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
)
)
return autotune_select_algorithm(choices, [inp_expanded, mat1, mat2], layout)

View File

@ -0,0 +1,125 @@
import functools
import logging
import sympy
import torch
from torch._inductor.select_algorithm import realize_inputs
from torch._inductor.virtualized import V
from ..utils import ceildiv as cdiv
log = logging.getLogger(__name__)
@functools.lru_cache(None)
def mm_configs():
import triton
return [
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=2, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=3, num_warps=4
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=3, num_warps=4
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=8
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=4, num_warps=8
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32}, num_stages=5, num_warps=8
),
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=5, num_warps=8
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=2, num_warps=8
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=3, num_warps=8
),
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128}, num_stages=2, num_warps=8
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16}, num_stages=2, num_warps=4
),
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 16}, num_stages=1, num_warps=2
),
]
def mm_grid(m, n, meta):
"""
The CUDA grid size for matmul triton templates.
"""
return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), 1, 1)
def acc_type(dtype):
if dtype in (torch.float16, torch.bfloat16):
return "tl.float32"
return f"tl.{dtype}".replace("torch.", "")
def mm_options(config, sym_k, layout):
"""
Common options shared by the matmul triton templates.
"""
even_k_symbolic = (
# it isn't worth guarding on this
sympy.gcd(sym_k, config.kwargs["BLOCK_K"])
== config.kwargs["BLOCK_K"]
)
return dict(
GROUP_M=8,
EVEN_K=even_k_symbolic,
ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32,
ACC_TYPE=acc_type(layout.dtype),
num_stages=config.num_stages,
num_warps=config.num_warps,
**config.kwargs,
)
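To make the EVEN_K computation concrete, a small sympy sketch (the symbol and block size below are assumptions): the flag only becomes True when BLOCK_K provably divides K, which lets the template drop the masked tail loads.

import sympy

K = sympy.Symbol("K", integer=True, positive=True)
BLOCK_K = 32
# Nothing is known about K, so gcd(K, 32) == 1 and EVEN_K stays False.
assert sympy.gcd(K, BLOCK_K) != BLOCK_K
# A K that is statically a multiple of BLOCK_K flips EVEN_K to True.
assert sympy.gcd(32 * K, BLOCK_K) == BLOCK_K
assert sympy.gcd(sympy.Integer(128), BLOCK_K) == BLOCK_K
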
def mm_args(mat1, mat2, *others, layout=None):
"""
Common arg processing for mm, bmm, addmm, etc.
"""
mat1, mat2 = realize_inputs(mat1, mat2)
*b1, m, k1 = mat1.get_size()
*b2, k2, n = mat2.get_size()
b = [V.graph.sizevars.guard_equals(a, b) for a, b in zip(b1, b2)]
k = V.graph.sizevars.guard_equals(k1, k2)
if layout is None:
from torch._inductor.ir import FixedLayout
layout = FixedLayout(
mat1.get_device(),
mat1.get_dtype(),
[*b, m, n],
)
from ..lowering import expand
others = [realize_inputs(expand(x, layout.size)) for x in others]
return [m, n, k, layout, mat1, mat2, *others]
def addmm_epilogue(dtype, alpha, beta):
def epilogue(acc, bias):
if alpha != 1:
acc = V.ops.mul(acc, V.ops.constant(alpha, dtype))
if beta != 1:
bias = V.ops.mul(bias, V.ops.constant(beta, dtype))
return V.ops.add(acc, bias)
return epilogue
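Numerically, the fused epilogue matches torch.addmm semantics (beta * bias + alpha * (mat1 @ mat2)); here is a minimal eager-mode sketch with arbitrary shapes, using plain tensor ops in place of the V.ops calls above.

import torch

mat1 = torch.randn(8, 4)
mat2 = torch.randn(4, 16)
bias = torch.randn(8, 16)  # already expanded to the output shape, as in mm_args
alpha, beta = 2.0, 0.5

acc = mat1 @ mat2                  # what the template accumulates
fused = alpha * acc + beta * bias  # what epilogue(acc, bias) computes
torch.testing.assert_close(
    fused, torch.addmm(bias, mat1, mat2, beta=beta, alpha=alpha)
)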

View File

@ -558,6 +558,7 @@ def trunc(x):
@register_lowering(aten.expand, type_promotion_kind=None)
def expand(x, sizes):
(x,) = promote_constants([x])
if isinstance(x, ir.BaseConstant):
return ExpandView.create(x, tuple(sizes))
assert isinstance(x, TensorBox)
@ -837,21 +838,6 @@ def glu(x, dim=-1):
return mul(a, sigmoid(b))
@register_lowering(aten.mm)
def mm(a: TensorBox, b: TensorBox):
return TensorBox.create(ir.MatrixMultiply.create(a, b))
@register_lowering(aten.addmm)
def addmm(inp: TensorBox, a: TensorBox, b: TensorBox, beta=1, alpha=1):
return TensorBox.create(ir.MatrixMultiplyAdd.create(inp, a, b, beta, alpha))
@register_lowering(aten.bmm)
def bmm(a: TensorBox, b: TensorBox):
return TensorBox.create(ir.BatchMatrixMultiply.create(a, b))
def register_onednn_fusion_ops():
if torch._C.has_mkldnn:
@ -3731,3 +3717,20 @@ def foobar(self, *args, **kwargs):
def _realize(x):
x.realize()
return clone(x)
def _import_kernels():
"""
Need to make sure all these get registered in the lowerings dict
"""
import importlib
import os
from . import kernel
for filename in sorted(os.listdir(os.path.dirname(kernel.__file__))):
if filename.endswith(".py") and filename[0] != "_":
importlib.import_module(f"{kernel.__name__}.{filename[:-3]}")
_import_kernels()

View File

@ -6,14 +6,14 @@ import logging
import os
import pprint
import textwrap
from typing import Dict, List, Optional, Set, Union
from typing import Dict, List, Optional, Set
import sympy
import torch
from . import config, dependencies, ir, metrics
from .dependencies import MemoryDep, StarDep
from .dependencies import StarDep
from .sizevars import SimplifyIndexing
from .utils import cache_on_self, cmp, dynamo_utils, has_triton
from .virtualized import V
@ -116,6 +116,9 @@ class BaseSchedulerNode:
def get_mutations(self):
return self.node.get_mutation_names()
def has_aliasing_or_mutation(self):
return bool(self.get_aliases() or self.get_mutations())
def set_read_writes(self, rw: dependencies.ReadWrites):
self.read_writes: dependencies.ReadWrites = rw
self.unmet_dependencies = self.read_writes.reads
@ -162,9 +165,7 @@ class BaseSchedulerNode:
return False
def allocate(self):
from .codegen.triton_template import should_use_template
if self.node.should_allocate() or should_use_template(self.node):
if self.node.should_allocate():
# if self.node should allocate or
# if self.node is generated by TritonKernelTemplates
# because Triton kernel could not allocate tensor itself
@ -223,32 +224,6 @@ class ExternKernelSchedulerNode(BaseSchedulerNode):
return True
class TemplateSchedulerNode(BaseSchedulerNode):
def __init__(self, scheduler: "Scheduler", node: ir.ExternKernel, group_fn):
super().__init__(scheduler, node)
(self._sizes, self._stride) = node.get_group_stride()
self.group = (node.get_device(), group_fn(self._sizes))
self.set_read_writes(node.get_read_writes())
self.update_dep_type()
def is_template(self):
return True
def update_dep_type(self):
assert len(self.read_writes.writes) == 1
write = self.read_writes.writes.pop()
if isinstance(write, StarDep):
name = write.name
canonicalized_index, canonicalized_size = self.node.canonicalize()
new_dep = MemoryDep(name, canonicalized_index, canonicalized_size)
self.read_writes.writes.add(new_dep)
else:
self.read_writes.writes.add(write)
def get_ranges(self):
return self._sizes
class NopKernelSchedulerNode(BaseSchedulerNode):
pass
@ -263,9 +238,15 @@ class SchedulerNode(BaseSchedulerNode):
self.group = (node.get_device(), group_fn(self._sizes))
self.set_read_writes(
dependencies.extract_read_writes(self._body, *self._sizes, normalize=True)
)
if self.is_template():
self.set_read_writes(node.normalized_read_writes())
else:
self.set_read_writes(
dependencies.extract_read_writes(
self._body, *self._sizes, normalize=True
)
)
if self.is_reduction():
# reduction has last (reduced) dim in its sizes, and some
# downstream dependencies get confused by it
@ -303,7 +284,10 @@ class SchedulerNode(BaseSchedulerNode):
return self._sizes
def is_reduction(self):
return bool(self.node.data.get_reduction_type())
return bool(self.node.get_reduction_type())
def is_template(self):
return isinstance(self.node, ir.TemplateBuffer)
def allocate(self):
if (
@ -313,8 +297,7 @@ class SchedulerNode(BaseSchedulerNode):
):
return super().allocate()
if config.inplace_buffers:
from .codegen.triton_template import should_use_template
if config.inplace_buffers and getattr(V.kernel, "mutations", None) is not None:
from .codegen.wrapper import buffer_reuse_key
ordered_reads = sorted(self.read_writes.reads, key=lambda x: x.name)
@ -338,7 +321,6 @@ class SchedulerNode(BaseSchedulerNode):
input_node.node.get_layout(),
(ir.MultiOutputLayout, ir.MutationLayout, ir.AliasedLayout),
)
and not should_use_template(input_node.node)
and buffer_reuse_key(input_node.node)
== buffer_reuse_key(self.node)
):
@ -398,7 +380,7 @@ class SchedulerNode(BaseSchedulerNode):
return dependencies.extract_read_writes(fn, sizes)
def can_inplace(self, read_dep: dependencies.MemoryDep):
if self.get_aliases():
if self.get_aliases() or self.is_template():
return False
if len(self.read_writes.writes) == 1 and hasattr(read_dep, "index"):
write_dep = next(iter(self.read_writes.writes))
@ -482,6 +464,10 @@ class FusedSchedulerNode(BaseSchedulerNode):
def get_device(self):
return self.group[0]
@cache_on_self
def has_aliasing_or_mutation(self):
return any(x.has_aliasing_or_mutation() for x in self.snodes)
# None of these need to be implemented, as a FusedSchedulerNode is just an
# abstraction for scheduling purposes
def update_mutated_names(self, renames: Dict[str, str]):
@ -561,8 +547,6 @@ class NodeUser:
class Scheduler:
@dynamo_utils.dynamo_timed
def __init__(self, nodes):
from .codegen.triton_template import should_use_template
super(Scheduler, self).__init__()
self.backends = {}
@ -577,12 +561,9 @@ class Scheduler:
), "All nodes passed to scheduling must have an origin"
if node.is_no_op():
self.nodes.append(NopKernelSchedulerNode(self, node))
elif isinstance(node, ir.ComputedBuffer):
elif isinstance(node, (ir.ComputedBuffer, ir.TemplateBuffer)):
group_fn = self.get_backend(node.get_device()).group_fn
self.nodes.append(SchedulerNode(self, node, group_fn))
elif isinstance(node, ir.ExternKernel) and should_use_template(node):
group_fn = self.get_backend(node.get_device()).group_fn
self.nodes.append(TemplateSchedulerNode(self, node, group_fn))
elif isinstance(node, ir.ExternKernel):
self.nodes.append(ExternKernelSchedulerNode(self, node))
else:
@ -901,6 +882,12 @@ class Scheduler:
return False # node2 must go before node1
if node2.is_template():
return False # only epilogues
if node1.is_template() and (
node2.has_aliasing_or_mutation()
or node2.is_reduction()
or not config.epilogue_fusion
):
return False
device = node1.get_device()
if device != node2.get_device():
@ -919,14 +906,8 @@ class Scheduler:
# node2 depends on node1 outputs
if not self.can_fuse_vertical(node1, node2):
return False
if node1.is_template():
from .codegen.triton_template import template_can_fuse
return template_can_fuse(node1, node2)
return self.get_backend(device).can_fuse_vertical(node1, node2)
else: # nodes don't depend on each other, but may have common reads
if node1.is_template():
return False
return self.get_backend(device).can_fuse_horizontal(node1, node2)
def can_fuse_vertical(self, node1, node2):
@ -981,6 +962,7 @@ class Scheduler:
abs(node2.min_order - node1.max_order),
)
return (
node1.is_template() == config.epilogue_fusion_first and memory_score > 0,
node1.is_reduction() == node2.is_reduction() and memory_score > 0,
memory_score,
proximity_score,
@ -1074,16 +1056,6 @@ class Scheduler:
node.codegen(V.graph.wrapper_code)
self.free_buffers()
def codegen_template_call(
self, scheduler_node: Union[FusedSchedulerNode, TemplateSchedulerNode]
):
from .codegen.triton_template import template_codegen
node, *epilogue = scheduler_node.get_nodes()
node.allocate()
template_codegen(self, node, epilogue)
self.free_buffers()
def create_backend(self, device: torch.device):
assert (
device.type != "cuda" or device.index is not None
@ -1141,7 +1113,8 @@ class Scheduler:
self.buffer_names_to_free.update(node.last_usage)
if node.is_template():
self.codegen_template_call(node)
node, *epilogue = node.get_nodes()
self.get_backend(device).codegen_template(node, epilogue)
elif node.is_extern():
self.codegen_extern_call(node)
elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):

View File

@ -0,0 +1,681 @@
import builtins
import functools
import inspect
import itertools
import logging
import sys
import textwrap
from io import StringIO
from typing import Any, List
from unittest.mock import patch
import sympy
import torch
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import counters, identity
from . import config, ir
from .codecache import code_hash, DiskCache, PyCodeCache
from .codegen.common import IndentedBuffer
from .codegen.triton import config_of, signature_of, texpr, TritonKernel, TritonPrinter
from .utils import do_bench, sympy_dot, sympy_product
from .virtualized import V
log = logging.getLogger(__name__)
# correctness checks struggle with fp16/tf32
VERIFY = False # dict(atol=1, rtol=0.05)
PRINT_AUTOTUNE = True
class KernelNamespace:
pass
# these objects are imported from the generated wrapper code
template_kernels = KernelNamespace()
extern_kernels = KernelNamespace()
class TritonTemplateKernel(TritonKernel):
def __init__(
self,
kernel_name,
input_nodes,
output_node,
defines,
num_stages,
num_warps,
grid_fn,
meta,
call_sizes,
use_jit=True,
prefix_args=0,
suffix_args=0,
epilogue_fn=identity,
):
super().__init__(sympy_product(output_node.get_size()), sympy.Integer(1))
self.input_nodes = input_nodes
self.output_node = output_node
self.named_input_nodes = {}
self.defines = defines
self.kernel_name = kernel_name
self.template_mask = None
self.use_jit = use_jit
self.num_stages = num_stages
self.num_warps = num_warps
self.grid_fn = grid_fn
self.meta = meta
self.call_sizes = call_sizes
# for templates with fixed epilogues
self.prefix_args = prefix_args
self.suffix_args = suffix_args
self.epilogue_fn = epilogue_fn
def jit_line(self):
if self.use_jit:
return "@triton.jit"
argdefs, _, signature = self.args.python_argdefs()
triton_meta = {
"signature": dict(enumerate(map(signature_of, signature))),
"device": V.graph.scheduler.current_device.index,
"constants": {},
}
triton_meta["configs"] = [config_of(signature)]
return (
f"@template(num_stages={self.num_stages}, num_warps={self.num_warps}, meta={triton_meta!r})\n"
+ "@triton.jit"
)
def def_kernel(self, *argnames):
"""
Hook called from template code to generate function def and
needed args.
"""
assert all(isinstance(x, str) for x in argnames)
renames = IndentedBuffer(initial_indent=1)
named_args = self.input_nodes[
self.prefix_args : len(self.input_nodes) - self.suffix_args
]
assert len(argnames) == len(named_args), (
len(argnames),
len(named_args),
self.prefix_args,
len(self.input_nodes),
)
for input_node in self.input_nodes[: self.prefix_args]:
# get args in correct order
self.args.input(input_node.get_name())
for name, input_node in zip(argnames, named_args):
arg_name = f"arg_{name}"
self.named_input_nodes[name] = input_node
self.args.input_buffers[input_node.get_name()] = arg_name
if input_node.get_layout().offset == 0:
renames.writeline(f"{name} = {arg_name}")
else:
offset = texpr(self.rename_indexing(input_node.get_layout().offset))
renames.writeline(f"{name} = {arg_name} + {offset}")
for input_node in self.input_nodes[len(self.input_nodes) - self.suffix_args :]:
# get args in correct order
self.args.input(input_node.get_name())
arg_defs, *_ = self.args.python_argdefs()
return "\n".join(
[
"import triton.language as tl",
"import triton",
f"from {config.inductor_import}.triton_ops.autotune import template",
f"from {config.inductor_import}.utils import instance_descriptor",
"",
self.jit_line(),
f"def {self.kernel_name}({', '.join(arg_defs)}):",
self.defines,
renames.getvalue(),
]
)
def size(self, name: str, index: int):
"""
Hook called from template code to get the size of an arg.
Will add needed args to pass it in if it is dynamic.
"""
assert isinstance(name, str)
assert isinstance(index, int)
val = self.named_input_nodes[name].get_size()[index]
return texpr(self.rename_indexing(val))
def stride(self, name, index):
"""
Hook called from template code to get the stride of an arg.
Will add needed args to pass it in if it is dynamic.
"""
assert isinstance(name, str)
assert isinstance(index, int)
val = self.named_input_nodes[name].get_stride()[index]
return texpr(self.rename_indexing(val))
def store_output(self, indices, val, mask):
"""
Hook called from template code to store the final output
(if the buffer hasn't been optimized away), then append any
epilogue fusions.
"""
assert isinstance(indices, (list, tuple))
assert isinstance(val, str)
assert isinstance(mask, str)
if self.template_mask is None:
indices = list(map(TritonPrinter.paren, indices))
index_symbols = [sympy.Symbol(x) for x in indices]
lengths = [
V.graph.sizevars.simplify(s) for s in self.output_node.get_size()
]
assert len(indices) == len(lengths)
# glue to make the generated code use the same indexing as the template
for name, range_tree_entry in zip(
indices, self.range_trees[0].construct_entries(lengths)
):
range_tree_entry.set_name(name)
contiguous_index = sympy_dot(
ir.FlexibleLayout.contiguous_strides(lengths), index_symbols
)
self.body.writeline("xindex = " + texpr(contiguous_index))
self.range_trees[0].lookup(
sympy.Integer(1), sympy_product(lengths)
).set_name("xindex")
self.template_mask = mask
self.template_indices = indices
output_index = self.output_node.get_layout().make_indexer()(index_symbols)
if output_index == contiguous_index:
output_index = sympy.Symbol("xindex")
epilogue_args = [val]
for input_node in itertools.chain(
self.input_nodes[: self.prefix_args],
self.input_nodes[len(self.input_nodes) - self.suffix_args :],
):
input_node.freeze_layout()
epilogue_args.append(input_node.make_loader()(index_symbols))
V.ops.store(
self.output_node.get_name(),
output_index,
self.epilogue_fn(*epilogue_args),
)
assert self.template_mask == mask
self.codegen_body()
return textwrap.indent(self.body.getvalue(), " ").strip()
def make_load(self, name, indices, mask):
"""
Optional helper called from template code to generate the code
needed to load from a tensor.
"""
assert isinstance(indices, (list, tuple))
assert isinstance(name, str)
assert isinstance(mask, str)
stride = self.named_input_nodes[name].get_stride()
indices = list(map(TritonPrinter.paren, indices))
assert len(indices) == len(stride)
index = " + ".join(
f"{texpr(self.rename_indexing(s))} * {i}" for s, i in zip(stride, indices)
)
return f"tl.load({name} + ({index}), {mask})"
def template_env(self):
"""
Generate the namespace visible in the template.
"""
return {
fn.__name__: fn
for fn in [
self.def_kernel,
self.size,
self.stride,
self.store_output,
self.make_load,
]
}
def indexing(
self,
index: sympy.Expr,
*,
copy_shape=None,
dense_indexing=False,
):
"""
Override the default indexing to use our custom mask and force
dense indexing.
"""
result, *mask = super().indexing(
index,
dense_indexing=False,
copy_shape=copy_shape,
override_mask=self.template_mask,
)
result += f" + tl.zeros({self.template_mask}.shape, tl.int32)"
return (result, *mask)
def initialize_range_tree(self, pid_cache):
super().initialize_range_tree(pid_cache)
# ignore default codegen
self.body.clear()
self.indexing_code.clear()
def call_kernel(self, code, name: str):
_, call_args, _ = self.args.python_argdefs()
for i in range(len(call_args)):
if V.graph.is_unspec_arg(call_args[i]):
call_args[i] = call_args[i] + ".item()"
call_args = ", ".join(call_args)
stream_name = code.write_get_cuda_stream(V.graph.scheduler.current_device.index)
V.graph.wrapper_code.add_import_once(f"import {self.grid_fn.__module__}")
meta = V.graph.wrapper_code.add_meta_once(self.meta)
grid_call = [texpr(V.graph.sizevars.simplify(s)) for s in self.call_sizes] + [
meta
]
grid_call = (
f"{self.grid_fn.__module__}.{self.grid_fn.__name__}({', '.join(grid_call)})"
)
code.writeline(
f"{name}.run({call_args}, grid={grid_call}, stream={stream_name})"
)
@functools.lru_cache(None)
def _jinja2_env():
try:
import jinja2
return jinja2.Environment(
undefined=jinja2.StrictUndefined,
)
except ImportError:
return None
class TritonTemplate:
index_counter = itertools.count()
all_templates = dict()
@staticmethod
def _template_from_string(source):
env = _jinja2_env()
if env is not None:
return env.from_string(source)
return None
def __init__(self, name: str, grid: Any, source: str, debug=False):
super().__init__()
self.name = name
self.grid = grid
self.template = self._template_from_string(source)
assert name not in self.all_templates, "duplicate template name"
self.all_templates[name] = self
self.debug = debug
def generate(
self,
input_nodes,
layout,
num_stages,
num_warps,
prefix_args=0,
suffix_args=0,
epilogue_fn=identity,
**kwargs,
):
assert self.template, "requires jinja2"
defines = StringIO()
for name, val in kwargs.items():
defines.write(f" {name} : tl.constexpr = {val}\n")
defines = defines.getvalue()
fake_out = ir.Buffer("buf_out", layout)
kernel_name = f"triton_{self.name}"
kernel_options = dict(
input_nodes=input_nodes,
defines=defines,
num_stages=num_stages,
num_warps=num_warps,
grid_fn=self.grid,
meta=kwargs,
call_sizes=layout.size,
prefix_args=prefix_args,
suffix_args=suffix_args,
epilogue_fn=epilogue_fn,
)
with patch.object(
V.graph, "get_dtype", self.fake_get_dtype(fake_out)
), TritonTemplateKernel(
kernel_name=kernel_name,
output_node=fake_out,
use_jit=True,
**kernel_options,
) as kernel:
# need to call render twice to get all the needed args right
self.template.render(
**kernel.template_env(),
**kwargs,
)
code = self.template.render(
**kernel.template_env(),
**kwargs,
)
if self.debug:
print("Generated Code:\n", code)
mod = PyCodeCache.load(code)
run = getattr(mod, kernel_name).run
_, call_args, _ = kernel.args.python_argdefs()
expected_args = [x.get_name() for x in input_nodes] + [fake_out.get_name()]
assert list(call_args) == expected_args, (call_args, expected_args)
extra_args = V.graph.sizevars.size_hints(
map(sympy.expand, call_args[len(expected_args) :])
)
assert not extra_args, "TODO: dynamic shapes"
def call(*args, out):
return run(
*args,
out,
*extra_args,
grid=self.grid(*out.size(), kwargs),
num_stages=num_stages,
num_warps=num_warps,
)
call.key = mod.key
call.__file__ = mod.__file__
kernel_hash_name = f"triton_{self.name}_{next(self.index_counter)}"
setattr(template_kernels, kernel_hash_name, call)
def make_kernel_render(out_node):
kernel = TritonTemplateKernel(
kernel_name="KERNEL_NAME",
output_node=out_node,
use_jit=False,
**kernel_options,
)
render = functools.partial(
self.template.render,
**kernel.template_env(),
**kwargs,
)
return kernel, render
return TritonTemplateCaller(
kernel_hash_name, input_nodes, layout, make_kernel_render
)
@staticmethod
def fake_get_dtype(fake_out):
_get_dtype_real = V.graph.get_dtype
def get_dtype(name):
if name == fake_out.get_name():
return fake_out.get_dtype()
return _get_dtype_real(name)
return get_dtype
class ExternKernelChoice:
def __init__(self, kernel, cpp_kernel=None, *, name=None):
super().__init__()
name = name or kernel.__name__
assert callable(kernel)
assert not hasattr(extern_kernels, name), "duplicate extern kernel"
self.name = name
self.cpp_kernel = cpp_kernel
setattr(extern_kernels, name, kernel)
def to_callable(self):
return getattr(extern_kernels, self.name)
def call_name(self):
return f"extern_kernels.{self.name}"
@functools.lru_cache(None)
def hash_key(self):
fn = self.to_callable()
parts = [
self.name,
getattr(fn, "__name__", ""),
getattr(fn, "__module__", ""),
]
try:
parts.append(inspect.getsource(fn))
except Exception:
pass
return code_hash("-".join(parts))
def bind(self, input_nodes, layout, **kwargs):
return ExternKernelCaller(self, input_nodes, layout, kwargs)
class ChoiceCaller:
def __init__(self, name, input_nodes, layout):
super().__init__()
self.name = name
self.layout = layout
self.input_nodes = input_nodes
class TritonTemplateCaller(ChoiceCaller):
def __init__(self, name, input_nodes, layout, make_kernel_render):
super().__init__(name, input_nodes, layout)
self.make_kernel_render = make_kernel_render
def __str__(self):
return f"TritonTemplateCaller({self.to_callable().__file__})"
def call_name(self):
return f"template_kernels.{self.name}"
def to_callable(self):
return getattr(template_kernels, self.name)
def hash_key(self):
return self.to_callable().key
def output_node(self):
return ir.TensorBox.create(
ir.TemplateBuffer(
layout=self.layout,
inputs=self.input_nodes,
make_kernel_render=self.make_kernel_render,
)
)
class ExternKernelCaller(ChoiceCaller):
def __init__(self, choice: ExternKernelChoice, input_nodes, layout, kwargs=None):
super().__init__(choice.name, input_nodes, layout)
self.choice = choice
self.kwargs = kwargs or {}
def to_callable(self):
fn = self.choice.to_callable()
if self.kwargs:
return functools.partial(fn, **self.kwargs)
else:
return fn
def hash_key(self):
return "/".join(
[
self.choice.hash_key(),
repr(self.kwargs),
]
)
def output_node(self):
return ir.TensorBox.create(
ir.ExternKernelOut(
layout=self.layout,
inputs=self.input_nodes,
kernel=self.choice.call_name(),
cpp_kernel=self.choice.cpp_kernel,
kwargs=self.kwargs,
)
)
class AlgorithmSelectorCache(DiskCache):
def __call__(self, choices: List[ChoiceCaller], input_nodes, layout):
if len(choices) == 1:
return choices[0].output_node()
def autotune():
benchmark_fn = self.make_benchmark_fn(choices, input_nodes, layout)
timings = {}
for choice in choices:
try:
timings[choice] = benchmark_fn(
choice.to_callable(), isinstance(choice, ExternKernelCaller)
)
except RuntimeError as e:
if "invalid argument" in str(e):
msg = textwrap.dedent(
f"""
{e}
From choice {choices.index(choice)}: {choice}
This may mean this GPU is too small for max_autotune mode.
"""
).strip()
if VERIFY:
raise RuntimeError(msg)
else:
log.warning(msg)
else:
raise
except AssertionError as e:
raise AssertionError(
f"Incorrect result from choice {choices.index(choice)} {choice}\n\n{e}"
)
self.log_results(choices[0].name, input_nodes, timings)
best_choice = builtins.min(timings, key=timings.__getitem__)
return choices.index(best_choice)
counters["inductor"]["select_algorithm_autotune"] += 1
key = [x.hash_key() for x in choices] + [self.key_of(x) for x in input_nodes]
return choices[self.lookup(key, autotune)].output_node()
@classmethod
def make_benchmark_fn(
cls,
choices,
input_nodes,
layout,
):
example_inputs = [cls.benchmark_example_value(x) for x in input_nodes]
example_inputs_extern = list(example_inputs)
for i in range(len(example_inputs)):
if input_nodes[i].get_layout().offset != 0:
offset = V.graph.sizevars.size_hint(input_nodes[i].get_layout().offset)
data = example_inputs_extern[i]
example_inputs_extern[i] = torch.as_strided(
data, data.size(), data.stride(), offset
)
out = cls.benchmark_example_value(layout)
out_extern = torch.as_strided(
out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset)
)
if VERIFY:
choices[0].to_callable()(*example_inputs_extern, out=out_extern)
expected = out_extern.clone()
def benchmark(algo, is_extern):
out.zero_()
if is_extern:
result = do_bench(lambda: algo(*example_inputs_extern, out=out_extern))
else:
result = do_bench(lambda: algo(*example_inputs, out=out))
if VERIFY:
torch.testing.assert_close(out_extern, expected, **VERIFY)
torch.cuda.synchronize() # shake out any CUDA errors
return result
return benchmark
@staticmethod
def log_results(name, input_nodes, timings):
if not PRINT_AUTOTUNE:
return
sizes = ", ".join(
[
"x".join(map(str, V.graph.sizevars.size_hints(n.get_size())))
for n in input_nodes
]
)
top_k = sorted(timings, key=timings.__getitem__)[:10]
best = top_k[0]
best_time = timings[best][0]
sys.stderr.write(f"AUTOTUNE {name}({sizes})\n")
for choice in top_k:
result = timings[choice]
sys.stderr.write(
f" {choice.name} {result[0]:.4f}s {best_time/result[0]:.1%}\n"
)
@staticmethod
def benchmark_example_value(node):
"""
Convert an ir.Buffer into a concrete torch.Tensor we can use for
benchmarking.
"""
if isinstance(node, ir.Layout):
node = ir.Buffer("fake", node)
return rand_strided(
V.graph.sizevars.size_hints(node.get_size()),
V.graph.sizevars.size_hints(node.get_stride()),
device=node.get_device(),
dtype=node.get_dtype(),
extra_size=V.graph.sizevars.size_hint(node.get_layout().offset),
)
@staticmethod
def key_of(node):
"""
Extract the pieces of an ir.Buffer that we should invalidate cached
autotuning results on.
"""
sizevars = V.graph.sizevars
return (
node.get_device().type,
str(node.get_dtype()),
*sizevars.size_hints(node.get_size()),
*sizevars.size_hints(node.get_stride()),
sizevars.size_hint(node.get_layout().offset),
)
autotune_select_algorithm = AlgorithmSelectorCache(__name__)
def realize_inputs(*args):
if len(args) == 1:
return ir.ExternKernel.require_stride1(ir.ExternKernel.realize_input(args[0]))
return [realize_inputs(x) for x in args]
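The selection machinery above amounts to: materialize example tensors, time every ChoiceCaller writing into the same output, and keep the fastest. A standalone hedged sketch of that timing loop, using CUDA events rather than triton.testing.do_bench (whose return format varies across Triton versions); pick_fastest and time_ms are illustrative names.

import torch

def time_ms(fn, iters=10):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    fn()  # warmup (also triggers any lazy compilation)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

def pick_fastest(candidates, inputs, out):
    # candidates: {name: callable(*inputs, out=out)}, mirroring ChoiceCaller.to_callable()
    timings = {name: time_ms(lambda: fn(*inputs, out=out)) for name, fn in candidates.items()}
    return min(timings, key=timings.get), timings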

View File

@ -367,6 +367,9 @@ class SizeVarAllocator(object):
out = sympy_subs(sympy.expand(expr), self.var_to_val)
return int(out)
def size_hints(self, exprs: List[Expr]) -> int:
return tuple(self.size_hint(x) for x in exprs)
def _lru_cache(self, fn, maxsize=None):
"""
Wrapper around functools.lru_cache that clears when replacements

View File

@ -3,6 +3,5 @@ from ..utils import has_triton
if has_triton():
from .conv import _conv, conv
from .conv1x1 import _conv1x1, conv1x1
from .matmul import _matmul_out, matmul_out
__all__ = ["_conv", "conv", "_conv1x1", "conv1x1", "_matmul_out", "matmul_out"]
__all__ = ["_conv", "conv", "_conv1x1", "conv1x1"]

View File

@ -12,7 +12,6 @@ import torch
from .. import config
from ..ir import ReductionHint, TileHint
from ..triton_ops.mm_perf_model import estimate_matmul_time
from ..utils import conditional_product, dynamo_utils, has_triton
from .conv_perf_model import (
early_config_prune as conv_early_config_prune,
@ -106,8 +105,10 @@ class CachingAutotuner(KernelInterface):
exec(
f"""
def launcher({', '.join(def_args)}, grid, stream):
# set_device(current_device()) # TODO(jansel): is this needed?
grid_0, grid_1, grid_2 = grid(grid_meta)
if callable(grid):
grid_0, grid_1, grid_2 = grid(grid_meta)
else:
grid_0, grid_1, grid_2 = grid
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared,
stream, bin.cu_function, None, None, None,
{', '.join(call_args)})
@ -398,7 +399,7 @@ def pointwise(size_hints, meta, tile_hint=None, filename=None):
if len(size_hints) == 1:
return cached_autotune([triton_config(size_hints, 1024)], meta=meta)
if len(size_hints) == 2:
if not config.triton.autotune or tile_hint == TileHint.SQUARE:
if not config.triton.autotune_pointwise or tile_hint == TileHint.SQUARE:
return cached_autotune([triton_config(size_hints, 32, 32)], meta=meta)
return cached_autotune(
[
@ -412,7 +413,7 @@ def pointwise(size_hints, meta, tile_hint=None, filename=None):
filename=filename,
)
if len(size_hints) == 3:
if not config.triton.autotune:
if not config.triton.autotune_pointwise:
return cached_autotune([triton_config(size_hints, 16, 16, 16)], meta=meta)
return cached_autotune(
[
@ -448,7 +449,7 @@ def reduction(size_hints, reduction_hint=False, meta=None, filename=None):
return cached_autotune([outer_config], meta=meta)
elif reduction_hint == ReductionHint.OUTER_TINY:
return cached_autotune([tiny_config], meta=meta)
if not config.triton.autotune:
if not config.triton.autotune_pointwise:
return cached_autotune(
[triton_config_reduction(size_hints, 32, 128)], meta=meta
)
@ -469,6 +470,15 @@ def reduction(size_hints, reduction_hint=False, meta=None, filename=None):
raise NotImplementedError(f"size_hints: {size_hints}")
def template(num_stages, num_warps, meta, filename=None):
"""
Compile a triton template
"""
return cached_autotune(
[triton.Config({}, num_stages=num_stages, num_warps=num_warps)], meta=meta
)
def conv_heuristics():
configs = [
triton.Config(
@ -552,126 +562,6 @@ def conv_heuristics():
return triton.autotune(configs, key, prune_configs_by=prune_configs_by)
def mm_heuristics():
from triton import heuristics
mm_heuristic = heuristics(
{
"EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
}
)
return mm_heuristic
def mm_autotune(get_io_bound_configs=False):
from triton.ops.matmul import get_configs_io_bound
from triton.ops.matmul_perf_model import early_config_prune as mm_early_config_prune
configs = [
# basic configs for compute-bound matmuls
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1},
num_stages=5,
num_warps=2,
),
# good for int8
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
num_stages=5,
num_warps=2,
),
]
if get_io_bound_configs:
configs += get_configs_io_bound()
key = ["M", "N", "K"]
prune_configs_by = {
"early_config_prune": mm_early_config_prune,
"perf_model": estimate_matmul_time,
"top_k": 10,
}
return triton.autotune(configs, key, prune_configs_by=prune_configs_by)
def grid(xnumel, ynumel=None, znumel=None):
"""Helper function to compute triton grids"""

View File

@ -1,274 +0,0 @@
import torch
from ..utils import has_triton
if has_triton():
import triton
import triton.language as tl
def init_to_zero(name):
return lambda nargs: nargs[name].zero_()
@triton.heuristics(
{
"EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
}
)
@triton.autotune(
configs=[
# basic configs for compute-bound matmuls
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1},
num_stages=5,
num_warps=2,
),
# additional configs
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "SPLIT_K": 1},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 64, "SPLIT_K": 1},
num_stages=2,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
num_stages=2,
num_warps=4,
),
# additional configs for K = 64
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "SPLIT_K": 1},
num_stages=1,
num_warps=8,
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
num_stages=1,
num_warps=8,
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1},
num_stages=1,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 64, "SPLIT_K": 1},
num_stages=1,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
num_stages=1,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
num_stages=5,
num_warps=2,
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1},
num_stages=1,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
num_stages=1,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
num_stages=1,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
num_stages=1,
num_warps=2,
),
],
# + get_configs_io_bound(),
key=["M", "N", "K"],
#
# key=["M", "N", "K"],
# prune_configs_by={
# "early_config_prune": early_config_prune,
# "perf_model": estimate_matmul_time,
# "top_k": 18,
# },
)
@triton.jit
def _kernel(
A,
B,
C,
M,
N,
K,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr,
SPLIT_K: tl.constexpr,
EVEN_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
):
# matrix multiplication
pid = tl.program_id(0)
pid_z = tl.program_id(1)
bid = tl.program_id(2)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
# pointers
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
A += bid * M * K
B += bid * K * N
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for k in range(K, 0, -BLOCK_K * SPLIT_K):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
a = tl.load(A, mask=rk[None, :] < k, other=0.0)
b = tl.load(B, mask=rk[:, None] < k, other=0.0)
acc += tl.dot(a, b)
A += BLOCK_K * SPLIT_K * stride_ak
B += BLOCK_K * SPLIT_K * stride_bk
acc = acc.to(C.dtype.element_ty)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
C += bid * M * N
mask = (rm < M)[:, None] & (rn < N)[None, :]
# handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(C, acc, mask=mask)
else:
tl.atomic_add(C, acc, mask=mask)
def bmm_out(a, b, out):
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1:
a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1:
b = b.contiguous()
# checks constraints
assert a.shape[2] == b.shape[1], "incompatible dimensions"
B, M, K = a.shape
_, _, N = b.shape
# allocates output
c = out
# accumulator types
ACC_TYPE = (
tl.float32
if a.dtype in [torch.float16, torch.bfloat16, torch.float32]
else tl.int32
)
# launch kernel
def grid(META):
return (
triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
META["SPLIT_K"],
B,
)
# grid = lambda META: (
# triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
# META["SPLIT_K"],
# B,
# )
# autotuner = _kernel[grid].kernel
_kernel[grid](a, b, c, M, N, K, K, 1, N, 1, N, 1, GROUP_M=8, ACC_TYPE=ACC_TYPE)
# print(autotuner.best_config)
# print(autotuner.configs_timings)

View File

@ -1,136 +0,0 @@
import torch
from ..utils import has_triton
if has_triton():
import triton
import triton.language as tl
from .autotune import mm_autotune, mm_heuristics
@mm_heuristics()
@mm_autotune(get_io_bound_configs=True)
@triton.jit
def _kernel(
A,
B,
C,
M,
N,
K,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
allow_tf32: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr,
SPLIT_K: tl.constexpr,
EVEN_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
):
# matrix multiplication
pid = tl.program_id(0)
pid_z = tl.program_id(1)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
# pointers
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for k in range(K, 0, -BLOCK_K * SPLIT_K):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
a = tl.load(A, mask=rk[None, :] < k, other=0.0)
b = tl.load(B, mask=rk[:, None] < k, other=0.0)
acc += tl.dot(a, b, allow_tf32=allow_tf32)
A += BLOCK_K * SPLIT_K * stride_ak
B += BLOCK_K * SPLIT_K * stride_bk
acc = acc.to(C.dtype.element_ty)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
# handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(C, acc, mask=mask)
else:
tl.atomic_add(C, acc, mask=mask)
class _matmul_out:
kernel = _kernel
@staticmethod
def _call(a, b, out, allow_tf32=True):
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1:
a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1:
b = b.contiguous()
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
M, K = a.shape
_, N = b.shape
# allocates output
c = out
# accumulator types
ACC_TYPE = (
tl.float32
if a.dtype in [torch.float16, torch.bfloat16, torch.float32]
else tl.int32
)
# launch kernel (grid defined with def instead of lambda to pass `make lint`)
def grid(META):
return (
triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
META["SPLIT_K"],
)
# grid = lambda META: (
# triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
# META["SPLIT_K"],
# )
_kernel[grid](
a,
b,
c,
M,
N,
K,
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
c.stride(0),
c.stride(1),
allow_tf32=allow_tf32,
GROUP_M=8,
ACC_TYPE=ACC_TYPE,
)
@staticmethod
def forward(a, b, out, allow_tf32=True):
return _matmul_out._call(a, b, out, allow_tf32)
matmul_out = _matmul_out.forward

View File

@ -1,90 +0,0 @@
import torch
def estimate_matmul_time(
# backend, device,
num_warps,
num_stages,
A,
B,
M,
N,
K,
BLOCK_M,
BLOCK_N,
BLOCK_K,
SPLIT_K,
debug=False,
**kwargs,
):
"""return estimated running time in ms
= max(compute, loading) + store"""
import triton
import triton._C.libtriton.triton as _triton
from triton.ops.matmul_perf_model import (
get_dram_gbps as get_dram_gbps,
get_tflops as get_tflops,
)
backend = _triton.runtime.backend.CUDA
device = torch.cuda.current_device()
dtype = A.dtype
dtsize = A.element_size()
num_cta_m = triton.cdiv(M, BLOCK_M)
num_cta_n = triton.cdiv(N, BLOCK_N)
num_cta_k = SPLIT_K
num_ctas = num_cta_m * num_cta_n * num_cta_k
# If the input is smaller than the block size
M, N = max(M, BLOCK_M), max(N, BLOCK_N)
# time to compute
total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS
tput = get_tflops(backend, device, num_ctas, num_warps, dtype)
compute_ms = total_ops / tput
# time to load data
num_sm = _triton.runtime.num_sm(backend, device)
active_cta_ratio = min(1, num_ctas / num_sm)
active_cta_ratio_bw1 = min(
1, num_ctas / 32
) # 32 active ctas are enough to saturate
active_cta_ratio_bw2 = max(
min(1, (num_ctas - 32) / (108 - 32)), 0
) # 32-108, remaining 5%
dram_bw = get_dram_gbps(backend, device) * (
active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05
) # in GB/s
l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?)
# assume 80% of (following) loads are in L2 cache
load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1))
load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1)
load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1))
load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1)
# total
total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB
total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024)
# loading time in ms
load_ms = total_dram / dram_bw + total_l2 / l2_bw
# estimate storing time
store_bw = dram_bw * 0.6 # :o
store_c_dram = M * N * dtsize * SPLIT_K / (1024 * 1024) # MB
if SPLIT_K == 1:
store_ms = store_c_dram / store_bw
else:
reduce_bw = store_bw
store_ms = store_c_dram / reduce_bw
# c.zero_()
zero_ms = M * N * 2 / (1024 * 1024) / store_bw
store_ms += zero_ms
total_time_ms = max(compute_ms, load_ms) + store_ms
if debug:
print(
f"Total time: {total_time_ms}ms, compute time: {compute_ms}ms, "
f"loading time: {load_ms}ms, store time: {store_ms}ms, "
f"Activate CTAs: {active_cta_ratio*100}%"
)
return total_time_ms

View File

@ -1,6 +1,7 @@
import collections
import contextlib
import functools
import logging
import math
import operator
import os
@ -17,9 +18,11 @@ import sympy
import torch
from torch.fx.immutable_collections import immutable_dict, immutable_list
from . import config
from . import config, config as inductor_config
from .cuda_properties import get_device_capability
log = logging.getLogger(__name__)
VarRanges = Dict[sympy.Expr, sympy.Expr]
# We import torchdynamo modules indirectly to allow a future rename to torch.dynamo
@ -30,6 +33,13 @@ dynamo_optimizations = import_module(f"{config.dynamo_import}.optimizations")
dynamo_testing = import_module(f"{config.dynamo_import}.testing")
dynamo_utils = import_module(f"{config.dynamo_import}.utils")
try:
from triton.testing import do_bench
except ImportError:
def do_bench(*args, **kwargs):
raise NotImplementedError("requires Triton")
@functools.lru_cache(None)
def has_triton():
@ -432,3 +442,21 @@ class DeferredLineBase:
def __len__(self):
return len(self.line)
@functools.lru_cache(None)
def is_big_gpu(index):
cores = torch.cuda.get_device_properties(index).multi_processor_count
if cores < 80: # V100
log.warning("not enough cuda cores to use max_autotune mode")
return False
return True
def use_triton_template(layout):
return (
inductor_config.max_autotune
and layout.device.type == "cuda"
and layout.dtype in (torch.float16, torch.bfloat16, torch.float32)
and is_big_gpu(layout.device.index or 0)
)
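In practice this gate means the Triton mm/bmm templates are only tried when max_autotune is on, the output layout is a CUDA fp16/bf16/fp32 tensor, and the GPU has at least 80 SMs (roughly V100-class or larger). A small sketch of the hardware check, assuming device index 0:

import torch

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    eligible = props.multi_processor_count >= 80
    print(
        f"{props.name}: {props.multi_processor_count} SMs, "
        f"triton templates {'considered' if eligible else 'skipped'} under max_autotune"
    )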