[inductor] Rewrite Triton templates + epilogue fusion (retry) (#91575)
This reverts commit 94262efc7d381ace82aa74ed2f5f5ec76f8fca95 to reland #91105 / #90738.

Fixes https://github.com/pytorch/torchdynamo/issues/2015

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91575
Approved by: https://github.com/ngimel
Committed by: PyTorch MergeBot
Parent: 6912f7c564
Commit: 7c1c239db1
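[Editor's note] The test file added below exercises the new path end to end. As a minimal sketch (not part of the commit; CUDA and a build containing this PR are assumed), the rewritten template autotuning and epilogue fusion are driven by two config flags this PR adds:

    # Minimal sketch based on the new test_select_algorithm.py below.
    import torch
    import torch._inductor.config as inductor_config
    import torch.nn.functional as F

    inductor_config.max_autotune = True      # benchmark aten vs. triton templates
    inductor_config.epilogue_fusion = True   # fuse pointwise ops into the template

    @torch.compile
    def linear_relu(x, w, b):
        # mm template with a relu epilogue fused in when a triton template wins autotuning
        return F.relu(F.linear(x, w, b))

    if torch.cuda.is_available():
        linear_relu(
            torch.randn(64, 32, device="cuda"),
            torch.randn(16, 32, device="cuda"),
            torch.randn(16, device="cuda"),
        )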
@@ -139,7 +139,7 @@ def main():
    if args.verbose:
        torch._inductor.config.debug = True

-    torch._inductor.config.triton.autotune = True
+    torch._inductor.config.triton.autotune_pointwise = True

    rows = []
    for model in (MicroBenchmarks.sum,):
@@ -118,6 +118,10 @@ REQUIRE_EVEN_HIGHER_TOLERANCE = {
    "tacotron2",
}

+REQUIRE_HIGHER_FP16_TOLERANCE = {
+    "drq",
+}
+
REQUIRE_COSINE_TOLERACE = {
    # Just keeping it here even though it's empty, in case we need this in the future.
}

@@ -335,6 +339,8 @@ class TorchBenchmarkRunner(BenchmarkRunner):
        cosine = self.args.cosine
        # Increase the tolerance for torch allclose
        if self.args.float16 or self.args.amp:
+            if name in REQUIRE_HIGHER_FP16_TOLERANCE:
+                return 1e-2, cosine
            return 1e-3, cosine
        if is_training and current_device == "cuda":
            tolerance = 1e-3
setup.py
@@ -1156,7 +1156,6 @@ def main():
                'include/THH/generic/*.h',
                'include/sleef.h',
                "_inductor/codegen/*.h",
-                "_inductor/codegen/*.j2",
                'share/cmake/ATen/*.cmake',
                'share/cmake/Caffe2/*.cmake',
                'share/cmake/Caffe2/public/*.cmake',
test/inductor/test_select_algorithm.py (new file, 146 lines)
@@ -0,0 +1,146 @@
# Owner(s): ["module: inductor"]
import functools
import logging
from unittest.mock import patch

import torch
import torch._dynamo.config as dynamo_config
import torch._inductor.config as inductor_config
import torch._inductor.select_algorithm as select_algorithm
import torch.nn.functional as F
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import counters
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA

torch.backends.cuda.matmul.allow_tf32 = False


def patches(fn):
    def skip_cache(self, key, generate):
        return generate()

    for patcher in [
        patch.object(dynamo_config, "log_level", logging.INFO),
        patch.object(dynamo_config, "verbose", True),
        patch.object(inductor_config, "debug", True),
        patch.object(inductor_config, "max_autotune", True),
        patch.object(inductor_config, "epilogue_fusion", True),
        patch.object(select_algorithm, "VERIFY", dict(atol=1e-4, rtol=1e-4)),
        patch.object(select_algorithm.AlgorithmSelectorCache, "lookup", skip_cache),
    ]:
        fn = patcher(fn)

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        counters.clear()
        torch.manual_seed(12345)
        assert (
            not torch.backends.cuda.matmul.allow_tf32
        ), "correctness testing is allergic to tf32"
        return fn(*args, **kwargs)

    return wrapped


class TestSelectAlgorithm(TestCase):
    @patches
    def test_linear_relu(self):
        @torch.compile
        def foo(input, weight, bias):
            return F.relu(F.linear(input, weight, bias))

        foo(
            torch.randn(64, 32, device="cuda"),
            torch.randn(16, 32, device="cuda"),
            torch.randn(16, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
        # It would be nice to assert this got fused into a single kernel, but that
        # only happens if we select a triton template (and not aten).

    @patches
    def test_addmm(self):
        @torch.compile
        def foo(input, weight, bias):
            return torch.addmm(bias, input, weight)

        foo(
            torch.randn(20, 33, device="cuda"),
            torch.randn(33, 16, device="cuda"),
            torch.randn(20, 16, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm(self):
        @torch.compile
        def foo(a, b):
            return torch.mm(a, b)

        foo(
            torch.randn(8, 32, device="cuda"),
            torch.randn(32, 8, device="cuda"),
        )
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm_skip(self):
        @torch.compile
        def foo(a, b):
            return torch.mm(a, b)

        foo(
            torch.randn(8, 32, device="cuda", dtype=torch.float64),
            torch.randn(32, 8, device="cuda", dtype=torch.float64),
        )
        # float64 not supported by tl.dot()
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0)

    @patches
    def test_bmm(self):
        @torch.compile
        def foo(a, b):
            return torch.bmm(a, b)

        foo(
            torch.randn(2, 8, 32, device="cuda"),
            torch.randn(2, 32, 8, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm_not_even_k(self):
        @torch.compile
        def foo(a, b):
            return torch.mm(a, b)

        foo(
            torch.randn(11, 22, device="cuda"),
            torch.randn(22, 33, device="cuda"),
        )
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_baddbmm(self):
        @torch.compile
        def foo(a, b, c):
            return torch.baddbmm(c, a, b)

        foo(
            torch.randn(2, 8, 32, device="cuda"),
            torch.randn(2, 32, 8, device="cuda"),
            torch.randn(2, 1, 8, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)


if __name__ == "__main__":
    from torch._inductor.utils import is_big_gpu

    if IS_LINUX and HAS_CUDA and is_big_gpu(0):
        run_tests()
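[Editor's note] The `patches` helper above composes several unittest.mock patchers onto one test. A toy illustration of the same pattern (editorial; `Cfg` is a stand-in, not from the diff):

    from unittest.mock import patch

    class Cfg:            # stand-in config object, for illustration only
        debug = False

    def composed(fn):
        # stacking patcher(fn) is equivalent to applying each decorator in turn
        for patcher in [patch.object(Cfg, "debug", True)]:
            fn = patcher(fn)
        return fn

    @composed
    def check():
        assert Cfg.debug is True   # patched only while check() runs

    check()
    assert Cfg.debug is False      # patch unwound on exit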
@@ -79,7 +79,7 @@ requires_multigpu = functools.partial(
    unittest.skipIf, not HAS_MULTIGPU, "requires multiple cuda devices"
)

-torch._inductor.config.triton.autotune = False  # too slow
+torch._inductor.config.triton.autotune_pointwise = False  # too slow


# For OneDNN bf16 path, OneDNN requires the cpu has intel avx512 with avx512bw,
@@ -2505,76 +2505,6 @@ class CommonTemplate:
        self.assertEqual(a.stride(), c.stride())
        self.assertEqual(c.stride()[2], 1)

    @requires_cuda()
    @patch.object(config.triton, "convolution", "triton")
    @patch.object(config.triton, "dense_indexing", "True")
    def test_triton_conv(self):
        @torch._dynamo.optimize("inductor", nopython=True)
        def triton_conv(
            x,
            w,
            bias,
            stride,
            padding,
            dilation,
            groups,
        ):
            y = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
            return y

        stride, padding, dilation, groups = (1, 1), (0, 0), (1, 1), 1
        dtype = torch.float32
        x = torch.randn((32, 128, 32, 32), dtype=dtype, device=self.device)
        w = torch.randn((32, 128, 1, 1), dtype=dtype, device=self.device)
        bias = torch.randn((32), dtype=dtype, device=self.device)

        y = triton_conv(x, w, bias, stride, padding, dilation, groups)
        y_correct = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
        self.assertTrue(same(y, y_correct, cos_similarity=True, tol=0.1))

    @requires_cuda()
    @patch.object(config.triton, "convolution", "autotune")
    @patch.object(config.triton, "dense_indexing", "True")
    def test_conv_autotune(self):
        @torch._dynamo.optimize("inductor", nopython=True)
        def triton_conv(
            x,
            w,
            bias,
            stride,
            padding,
            dilation,
            groups,
        ):
            y = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
            return y

        stride, padding, dilation, groups = (1, 1), (0, 0), (1, 1), 1
        dtype = torch.float32
        x = torch.randn((32, 128, 32, 32), dtype=dtype, device=self.device)
        w = torch.randn((32, 128, 1, 1), dtype=dtype, device=self.device)
        bias = torch.randn((32), dtype=dtype, device=self.device)

        y = triton_conv(x, w, bias, stride, padding, dilation, groups)
        y_correct = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
        self.assertTrue(same(y, y_correct, cos_similarity=True, tol=0.1))

    @patch.object(config.triton, "mm", "triton")
    def test_triton_mm2(self):
        @torch._dynamo.optimize("inductor", nopython=True)
        def fn(x, y):
            return torch.relu(torch.mm(x, y))

        N = 1024
        a = torch.randn([N, N], device=self.device, dtype=torch.float32)
        b = torch.randn([N, N], device=self.device, dtype=torch.float32)
        c1 = torch.relu(torch.mm(a, b))
        torch._inductor.metrics.reset()
        c = fn(a, b)
        assert torch.allclose(c1, c, atol=1e-3, rtol=1e-3)
        if self.device == "cuda":
            assert torch._inductor.metrics.generated_kernel_count == 1

    def test_std(self):
        def fn(x):
            return (

@@ -4560,12 +4490,6 @@ class CommonTemplate:
        )
        expected_kernel = 0
        # codegen mm kernel from template
        if config.triton.mm != "aten" and self.device == "cuda":
            expected_kernel = 1
            if config.triton.mm == "autotune":
                self.assertLessEqual(
                    torch._inductor.metrics.generated_kernel_count, expected_kernel
                )
        self.assertEqual(
            torch._inductor.metrics.generated_kernel_count, expected_kernel
        )

@@ -4641,15 +4565,6 @@ class CommonTemplate:
        result.sum().backward()

        expected_kernel = 4
        if config.triton.mm != "aten" and self.device == "cuda":
            # fwd: 2 * (mm+dropout) kernels = 2 kernels
            # bwd: dropout + (mm) + 2 * (mm+dropout) kernels = 4 kernels
            # expect 2 + 4 = 6 kernels
            expected_kernel = 6
            if config.triton.mm == "autotune":
                self.assertLessEqual(
                    torch._inductor.metrics.generated_kernel_count, expected_kernel
                )
        self.assertEqual(
            torch._inductor.metrics.generated_kernel_count, expected_kernel
        )

@@ -4979,7 +4894,6 @@ class CommonTemplate:
        inputs = (inputs[1], inputs[0])
        self.assertTrue(same(opt(*inputs), fn(*inputs)))

-    @patch.object(config.triton, "mm", "aten")
    def test_list_clearing(self):

        if self.device == "cpu":

@@ -5685,7 +5599,7 @@ if HAS_CUDA:
            res = opt_mod(*args)
            self.assertTrue(same(ref, res))

-        @patch.object(config.triton, "autotune", True)
+        @patch.object(config.triton, "autotune_pointwise", True)
        def test_inplace_add_alpha_autotune(self):
            def fn(x, y):
                aten.add_.Tensor(x, y, alpha=0.55)

@@ -5703,7 +5617,7 @@ if HAS_CUDA:
            fn_compiled([x3, y])
            assert same(x2, x3)

-        @patch.object(config.triton, "autotune", True)
+        @patch.object(config.triton, "autotune_pointwise", True)
        def test_inplace_buffer_autotune(self):
            def foo(x, y, z):
                a = x @ y
@@ -242,8 +242,12 @@ def requires_static_shapes(fn):
    return _fn


-def rand_strided(size, stride, dtype=torch.float32, device="cpu"):
-    needed_size = sum((shape - 1) * stride for shape, stride in zip(size, stride)) + 1
+def rand_strided(size, stride, dtype=torch.float32, device="cpu", extra_size=0):
+    needed_size = (
+        sum((shape - 1) * stride for shape, stride in zip(size, stride))
+        + 1
+        + extra_size
+    )
    if dtype.is_floating_point:
        buffer = torch.randn(needed_size, dtype=dtype, device=device)
    else:
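[Editor's note] A quick worked check of the needed_size arithmetic above (not part of the diff): the largest reachable offset of a strided view is sum((size-1)*stride), so one more element always suffices, and extra_size pads past the end:

    import torch

    size, stride, extra_size = (2, 3), (4, 1), 0
    # furthest element: (2-1)*4 + (3-1)*1 = 6, so 7 elements are enough
    needed_size = sum((s - 1) * st for s, st in zip(size, stride)) + 1 + extra_size
    buf = torch.randn(needed_size)
    view = torch.as_strided(buf, size, stride)
    assert needed_size == 7 and view.shape == (2, 3)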
@@ -3,6 +3,7 @@ import dataclasses
import functools
import getpass
import hashlib
+import json
import logging
import multiprocessing
import os
@@ -60,6 +61,36 @@ def cache_dir():
    )


class DiskCache:
    @staticmethod
    @functools.lru_cache(None)
    def _subdir():
        subdir = os.path.join(cache_dir(), "cached_tunings")
        os.makedirs(subdir, exist_ok=True)
        return subdir

    @staticmethod
    @functools.lru_cache(4096)
    def _read_file(path):
        with open(path, "r") as fd:
            return json.loads(fd.read())

    def __init__(self, unique_name):
        super().__init__()
        self.unique_name = unique_name

    def lookup(self, key: Any, generate: Callable[[], Any]):
        """
        Check if we have already generated key; if not, call generate()
        to populate the cache.
        """
        path = os.path.join(self._subdir(), code_hash(self.unique_name + repr(key)))
        if not os.path.exists(path):
            value = generate()
            write_atomic(path, json.dumps(value))
        return self._read_file(path)


def get_lock_dir():
    lock_dir = os.path.join(cache_dir(), "locks")
    if not os.path.exists(lock_dir):
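[Editor's note] A small usage sketch of the DiskCache contract above (editorial; the key just needs a stable repr()): generate() runs at most once per key, and the JSON-serialized result is re-read on later lookups:

    cache = DiskCache("mm_tunings")      # unique_name namespaces the keys on disk

    def generate():
        # expensive autotuning/benchmarking would go here
        return {"best_kernel": "triton_mm"}

    cfg = cache.lookup(("mm", 64, 64, 64), generate)  # miss: runs generate()
    cfg = cache.lookup(("mm", 64, 64, 64), generate)  # hit: reads the cached JSON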
@@ -88,14 +119,18 @@ def write(source_code, ext, extra=""):
    if not os.path.exists(subdir):
        os.makedirs(subdir, exist_ok=True)
    if not os.path.exists(path):
-        # use a temp file for thread safety
-        fd, tmp_path = tempfile.mkstemp(dir=subdir)
-        with os.fdopen(fd, "w") as f:
-            f.write(source_code)
-        os.rename(tmp_path, path)
+        write_atomic(path, source_code)
    return basename, path


+def write_atomic(path: str, source_code: str):
+    # use a temp file for thread safety
+    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(path))
+    with os.fdopen(fd, "w") as f:
+        f.write(source_code)
+    os.rename(tmp_path, path)


def cpp_compiler():
    if isinstance(config.cpp.cxx, (list, tuple)):
        search = tuple(config.cpp.cxx)
@@ -3,7 +3,6 @@ import builtins
import torch

from .. import config, triton_ops
-from ..triton_ops.autotune import mm_autotune, mm_heuristics
from ..utils import dynamo_testing
from ..virtualized import V
@@ -141,79 +140,6 @@ def tuned_conv(
    return best_kernel


def tuned_mm(
    a_shape,
    b_shape,
    a_stride,
    b_stride,
    device,
    dtype,
    adjust_triton=0.95,
):
    """
    Return the best kernel name given mm input size;
    considering potential pointwise fusion of mm, we could adjust triton timing
    by multiplying adjust_triton (default=0.95)
    """

    sizevars = V.graph.sizevars
    a_shape = [sizevars.size_hint(s) for s in a_shape]
    b_shape = [sizevars.size_hint(s) for s in b_shape]
    a_stride = [sizevars.size_hint(s) for s in a_stride]
    b_stride = [sizevars.size_hint(s) for s in b_stride]
    a = rand_strided(a_shape, a_stride, device=device, dtype=dtype)
    b = rand_strided(b_shape, b_stride, device=device, dtype=dtype)
    c = torch.empty((a_shape[0], b_shape[1]), device=device, dtype=dtype)
    id_args = [
        *a_shape,
        *b_shape,
    ]
    use_cuda = a.is_cuda

    # gen_key
    key = tuple(id_args)
    key = ("mm",) + key

    # candidate kernels
    kernels = ["aten.mm.out"]
    if use_cuda:
        kernels += ["triton_ops.matmul_out"]
    # if only one choice, return that kernel
    if len(kernels) == 1:
        kernel = kernels[0]
        return kernel
    timings = {}
    if key not in autotune.cache:
        # bench_start = time.time()
        for kernel in kernels:
            runnable_kernel = str2func(kernel)
            if "triton_ops" in kernel:
                run_args = (a, b, c)
                run_kwargs = {}
                inner_kernel = str2func(
                    kernel.replace("matmul_out", "_matmul_out") + ".kernel"
                )
                inner_kernel.kernel_decorators = []
                # fix SPLIT_K = 1 for fusable kernels
                mm_heuristics()(mm_autotune(get_io_bound_configs=False)(inner_kernel))
            else:
                run_args = (a, b)
                run_kwargs = {"out": c}
            timing, _, _ = autotune._bench(runnable_kernel, *run_args, **run_kwargs)
            if "triton_ops" in kernel:
                timing = timing * adjust_triton
            timings[kernel] = timing
        # bench_end = time.time()
        # bench_time = bench_end - bench_start
        autotune.cache[key] = builtins.min(timings, key=timings.get)
        if config.debug:
            print("for key = ", key)
            print("timing", timings)
            print("best_kernel", autotune.cache[key])
    best_kernel = autotune.cache[key]
    return best_kernel


def tuned_conv_layout(
    kernel,
    x_shape,
@@ -1,4 +1,3 @@
-import collections
import contextlib
import itertools
import logging
@@ -200,13 +199,29 @@ class KernelArgs:
        return odict[name]

    def __init__(self, sizevars=None):
-        self.input_buffers = collections.OrderedDict()
-        self.output_buffers = collections.OrderedDict()
-        self.inplace_buffers = collections.OrderedDict()
-        self.sizevars = sizevars or collections.OrderedDict()
+        self.input_buffers = dict()
+        self.output_buffers = dict()
+        self.inplace_buffers = dict()
+        self.sizevars = sizevars or dict()
+
+    def __repr__(self):
+        return "KernelArgs({})".format(
+            ", ".join(
+                map(
+                    repr,
+                    [
+                        self.input_buffers,
+                        self.output_buffers,
+                        self.inplace_buffers,
+                        self.sizevars,
+                    ],
+                )
+            )
+        )

    def input(self, name):
-        name = V.graph.scheduler.mutation_real_name.get(name, name)
+        if V.graph.scheduler:
+            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.output_buffers:
            return self.output_buffers[name]

@@ -217,7 +232,8 @@ class KernelArgs:
        return self._lookup("in_ptr", self.input_buffers, name)

    def output(self, name):
-        name = V.graph.scheduler.mutation_real_name.get(name, name)
+        if V.graph.scheduler:
+            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name

@@ -428,7 +444,10 @@ class CSE:
            var = self.newvar()
            self.cache[expr] = var
            if write:
-                V.kernel.current_node.codegen_originating_info(buffer, only_once=True)
+                if V.kernel.current_node:
+                    V.kernel.current_node.codegen_originating_info(
+                        buffer, only_once=True
+                    )
                buffer.writeline(f"{self.prefix}{var} = {expr}{self.suffix}")
        return self.cache[expr]

@@ -552,8 +571,9 @@ class Kernel(CodeGen):
        self.store_buffer_names.add(name)
        if mode is None:
            self.cse.store_cache[name] = value
-            for other_name in self.current_node.get_mutations():
-                self.cse.store_cache[other_name] = value
+            if self.current_node:
+                for other_name in self.current_node.get_mutations():
+                    self.cse.store_cache[other_name] = value
        if name not in V.graph.removed_buffers:
            return self.store(name, index, value, mode=mode)

@@ -571,7 +591,8 @@ class Kernel(CodeGen):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
-        V.graph.scheduler.remove_kernel_local_buffers()
+        if V.graph.scheduler:
+            V.graph.scheduler.remove_kernel_local_buffers()
        super().__exit__(exc_type, exc_val, exc_tb)

    def rename_indexing(self, index) -> sympy.Expr:
@@ -416,14 +416,17 @@ class IterationRangesRoot(IterationRanges):
        self.nodes[expr] = node
        return self.nodes[expr]

-    def construct(self, lengths: List[sympy.Expr]):
+    def construct_entries(self, lengths: List[sympy.Expr]):
        divisor = sympy.Integer(1)
        itervars = []
        for length in reversed(lengths):
-            itervars.append(self.lookup(divisor, length).symbol())
+            itervars.append(self.lookup(divisor, length))
            divisor = divisor * length
        return list(reversed(itervars))

+    def construct(self, lengths: List[sympy.Expr]):
+        return [e.symbol() for e in self.construct_entries(lengths)]
+
    def vars_and_sizes(self, index: sympy.Expr):
        """Figure out vars from this tree used in index"""
        nodes = [V.kernel.range_tree_nodes.get(s) for s in index.free_symbols]

@@ -497,6 +500,11 @@ class IterationRangesEntry(IterationRanges):
        self.codegen = functools.lru_cache(None)(self._codegen)
        self.expr = expr

+    def set_name(self, name):
+        self.codegen = lambda: name
+        self.codegen.cache_clear = lambda: None
+        self.name = name
+
    def cache_clear(self):
        self.codegen.cache_clear()
@@ -732,6 +740,7 @@ class TritonKernel(Kernel):
        *,
        copy_shape=None,
        dense_indexing=False,
+        override_mask=None,
    ):
        """
        Compute the index and mask to pass to tl.load() or tl.store()

@@ -742,7 +751,9 @@ class TritonKernel(Kernel):

        mask_vars: Set[str] = set()
        for var in index_vars:
-            if var.name.startswith("tmp"):
+            if override_mask:
+                pass
+            elif var.name.startswith("tmp"):
                # indirect indexing
                cse_var = self.cse.varname_map[var.name]
                mask_vars.update(cse_var.mask_vars)

@@ -770,7 +781,10 @@ class TritonKernel(Kernel):
            dense_mask_vars.add(f"{tree.prefix}mask")

        if (need_dense and not have_dense) or isinstance(index, sympy.Integer):
-            index_str = f"{index_str} + tl.zeros({self.dense_size_str()}, tl.int32)"
+            if copy_shape:
+                index_str = f"{index_str} + tl.zeros({copy_shape}.shape, tl.int32)"
+            else:
+                index_str = f"{index_str} + tl.zeros({self.dense_size_str()}, tl.int32)"
            if isinstance(index, sympy.Integer):
                return index_str, set(), "None"
            else:

@@ -779,6 +793,9 @@ class TritonKernel(Kernel):
            mask_vars = dense_mask_vars
            index_str = f"{index_str} + tl.zeros({copy_shape}.shape, tl.int32)"

+        if override_mask:
+            mask_vars = {override_mask}
+
        if self._load_mask:
            mask_vars.add(self._load_mask)

@@ -829,6 +846,7 @@ class TritonKernel(Kernel):
            ep = ", eviction_policy='evict_last'"
        else:
            ep = ""
+
        # "other" below is a workaround for https://github.com/openai/triton/issues/737
        # for bool, even though it's likely subject to the same bug, setting `other` leads
        # to LLVM errors so we are skipping it for now

@@ -1106,6 +1124,13 @@ class TritonKernel(Kernel):
        wrapper.writeline("''')")
        return wrapper.getvalue()

+    def codegen_template_wrapper(self, src_code):
+        wrapper = IndentedBuffer()
+        wrapper.writeline("async_compile.triton('''")
+        wrapper.splice(src_code, strip=True)
+        wrapper.writeline("''')")
+        return wrapper.getvalue()
+
    def codegen_static_numels(self, code):
        """
        We get a small speedup from hard coding numels if they are static.

@@ -1187,6 +1212,9 @@ class TritonScheduling:
        if not (numel1 == numel2 and rnumel1 == rnumel2):
            return False

+        if node1.is_template():
+            return True  # skip checks for compatible tiling
+
        # check for a bad combined tiling
        tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1)
        tiling2 = self.select_tiling(node2.get_nodes(), numel1, rnumel1)

@@ -1212,7 +1240,10 @@ class TritonScheduling:
                for n in node1.get_nodes()
            ):
                return False
-            if config.triton.tiling_prevents_reduction_fusion:
+            if (
+                config.triton.tiling_prevents_reduction_fusion
+                and not node1.is_template()
+            ):
                return self.select_tiling(node1.get_nodes(), numel1) in (
                    (numel1, 1),
                    (numel2, rnumel2, 1),

@@ -1348,8 +1379,13 @@ class TritonScheduling:
                index_vars = kernel.split_and_set_ranges(node.get_ranges())
                node.codegen(index_vars)

-        wrapper = V.graph.wrapper_code
        src_code = kernel.codegen_kernel()
+        kernel_name = self.define_kernel(src_code, node_schedule)
+        kernel.call_kernel(V.graph.wrapper_code, kernel_name)
+        self.scheduler.free_buffers()
+
+    def define_kernel(self, src_code, node_schedule):
+        wrapper = V.graph.wrapper_code
        if src_code in wrapper.kernels:
            kernel_name = wrapper.kernels[src_code]
        else:

@@ -1366,7 +1402,25 @@ class TritonScheduling:
            # not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
            src_code = src_code.replace("#pragma CMT", "#")
            wrapper.define_kernel(kernel_name, src_code)
-        kernel.call_kernel(wrapper, kernel_name)
        return kernel_name

+    def codegen_template(self, template_node, epilogue_nodes):
+        """
+        Codegen a triton template
+        """
+        _, (numel, rnumel) = template_node.group
+        assert rnumel == 1
+        kernel, render = template_node.node.make_kernel_render(template_node.node)
+        with kernel:
+            for node in [template_node, *epilogue_nodes]:
+                node.mark_run()
+            render()  # warmup run to get the args right
+            for node in epilogue_nodes:
+                node.codegen(kernel.split_and_set_ranges(node.get_ranges()))
+
+        src_code = kernel.codegen_template_wrapper(render())
+        kernel_name = self.define_kernel(src_code, [template_node, *epilogue_nodes])
+        kernel.call_kernel(V.graph.wrapper_code, kernel_name)
+        self.scheduler.free_buffers()
+
    def codegen_sync(self):
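[Editor's note] The two .j2 files deleted below were Jinja2 templates rendered by the (also deleted) TritonTemplateKernel further down, which filled holes such as {{kernel_name}} and {% if pointwise_code %}. A toy render showing the mechanism (editorial; names illustrative only, Environment options copied from the deleted code):

    from jinja2 import Environment, StrictUndefined

    env = Environment(trim_blocks=True, lstrip_blocks=True, undefined=StrictUndefined)
    toy = env.from_string(
        "def {{kernel_name}}(...):\n"
        "{% if pointwise_code %}"
        "{{ pointwise_code | indent(4, true) }}\n"
        "{% endif %}"
    )
    # prints: "def triton_mm_0(...):"  then  "    acc = relu(acc)"
    print(toy.render(kernel_name="triton_mm_0", pointwise_code="acc = relu(acc)"))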
@@ -1,181 +0,0 @@
@conv_heuristics()
@triton.jit
def {{kernel_name}}(
    {% for i in template_inout_argdefs %}
    {{i}},
    {% endfor %}
    # stride of tensor
    stride_xn,
    stride_xc,
    stride_xh,
    stride_xw,
    stride_wn,
    stride_wc,
    stride_wh,
    stride_ww,
    stride_yn,
    stride_yc,
    stride_yh,
    stride_yw,
    stride_biasn,
    # Tensor dimensions
    BATCH,
    IN_C,
    IN_H,
    IN_W,
    KERNEL_N,
    KERNEL_H,
    KERNEL_W,
    OUT_H,
    OUT_W,
    # parameters of conv
    stride_h,
    stride_w,
    padding_h,
    padding_w,
    dilation_h,
    dilation_w,
    output_padding_h,
    output_padding_w,
    groups: tl.constexpr,
    # pointer inc for x
    delta_x_ptr,
    # fusable kernels args
    {% for i in extra_argdefs %}
    {{i}},
    {% endfor %}
    # Metaparameters
    ACC_TYPE: tl.constexpr,
    CONV1X1_NHWC: tl.constexpr,
    # blocks in different dimension
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    # reduction tiling parameter for matmul
    BLOCK_K: tl.constexpr,
):
    """
    each program instance computes a [BLOCK_BATCH, BLOCK_N, BLOCK_H, BLOCK_W] block of y
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of y it should compute.
    pid_nhw = tl.program_id(0)
    pid_k = tl.program_id(1)

    # offset for output y
    off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
    off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
    off_y_n = off_y_nhw // (OUT_H * OUT_W)
    off_y_hw = off_y_nhw % (OUT_H * OUT_W)
    off_y_h = off_y_hw // OUT_W
    off_y_w = off_y_hw % OUT_W

    # offset for the initial ptr for x
    off_x_n = off_y_n
    off_x_h = off_y_h * stride_h - padding_h
    off_x_w = off_y_w * stride_w - padding_w
    off_x_nhw = off_x_n * stride_xn + off_x_h * stride_xh + off_x_w * stride_xw
    off_x_crs = tl.arange(0, BLOCK_K)

    CRS = IN_C * KERNEL_H * KERNEL_W
    # load inc ptr of x, update x_ptrs
    if not CONV1X1_NHWC:
        delta_x_ptrs = delta_x_ptr + off_x_crs
        off_x_crs_unpacked = tl.load(delta_x_ptrs, mask=off_x_crs < CRS, other=0)
        x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
    else:
        x_ptrs = x + off_x_nhw[:, None] + off_x_crs[None, :]

    mask_x = (
        (off_x_n < BATCH)
        & (off_x_h >= 0)
        & (off_x_h < IN_H)
        & (off_x_w >= 0)
        & (off_x_w < IN_W)
    )[:, None] & (off_x_crs < CRS)[None, :]

    # offset for the initial ptr for w
    off_w_crs = tl.arange(0, BLOCK_K)
    off_w_k = off_y_k
    w_ptrs = w + off_w_crs[:, None] + off_w_k[None, :] * stride_wn
    mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]

    # ------ load x ------
    matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
    # ------ load w ------
    matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)

    # -----------------------------------------------------------
    # allocate accumulator
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for crs in range(0, CRS, BLOCK_K):

        # ------ matrix multiplication ------
        acc += tl.dot(matrix_x, matrix_w)
        # ------ update ptrs ------
        w_ptrs += BLOCK_K
        # load inc ptr of x, update x_ptrs
        if not CONV1X1_NHWC:
            delta_x_ptrs += BLOCK_K
            off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K)
            off_x_crs_unpacked = tl.load(delta_x_ptrs, mask=off_x_crs < CRS, other=0)
            x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
        else:
            off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K)
            x_ptrs += BLOCK_K

        mask_x = (
            (off_x_n < BATCH)
            & (off_x_h >= 0)
            & (off_x_h < IN_H)
            & (off_x_w >= 0)
            & (off_x_w < IN_W)
        )[:, None] & (off_x_crs < CRS)[None, :]
        mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]
        # ------ prefetch ------
        # ------ load x ------
        matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
        # ------ load w ------
        matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)

    acc = acc.to({{out_def}}.dtype.element_ty)

    {% if keep_store %}
    # rematerialize -- this saves some registers
    # offset for output y
    off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
    off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
    off_y_n = off_y_nhw // (OUT_H * OUT_W)
    off_y_hw = off_y_nhw % (OUT_H * OUT_W)
    # consider output padding
    off_y_h = off_y_hw // OUT_W + output_padding_h
    off_y_w = off_y_hw % OUT_W + output_padding_w

    # y ptrs in the block of [BLOCK_M, BLOCK_N]
    y_ptrs = (
        {{out_def}}
        + off_y_n[:, None] * stride_yn
        + off_y_h[:, None] * stride_yh
        + off_y_w[:, None] * stride_yw
        + off_y_k[None, :] * stride_yc
    )

    # out-of-bounds check
    mask_y = (
        (off_y_n < BATCH)[:, None]
        & (off_y_h < OUT_H + output_padding_h)[:, None]
        & (off_y_w < OUT_W + output_padding_w)[:, None]
        & (off_y_k < KERNEL_N)[None, :]
    )
    tl.store(y_ptrs, acc, mask=mask_y)
    {% endif %}

    {% if pointwise_code %}
    {{ pointwise_code | indent(4, true) }}
    {#
    z = tl.load(z_ptrs, mask=mask_z)
    acc += z
    #}
    {% endif %}

    return
@@ -1,200 +0,0 @@
@conv_heuristics()
@triton.jit
def {{kernel_name}}(
    {% for i in template_inout_argdefs %}
    {{i}},
    {% endfor %}
    # stride of tensor
    stride_xn,
    stride_xc,
    stride_xh,
    stride_xw,
    stride_wn,
    stride_wc,
    stride_wh,
    stride_ww,
    stride_yn,
    stride_yc,
    stride_yh,
    stride_yw,
    stride_biasn,
    # Tensor dimensions
    BATCH,
    IN_C,
    IN_H,
    IN_W,
    KERNEL_N,
    KERNEL_H,
    KERNEL_W,
    OUT_H,
    OUT_W,
    # parameters of conv
    stride_h,
    stride_w,
    padding_h,
    padding_w,
    dilation_h,
    dilation_w,
    output_padding_h,
    output_padding_w,
    groups,
    # pointer inc for x
    delta_xh_ptr,
    delta_xw_ptr,
    delta_xc_ptr,
    # fusable kernels args
    {% for i in extra_argdefs %}
    {{i}},
    {% endfor %}
    # Metaparameters
    ACC_TYPE: tl.constexpr,
    CONV1X1_NHWC: tl.constexpr,
    # blocks in different dimension
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    # reduction tiling parameter for matmul
    BLOCK_K: tl.constexpr,
):
    """
    each program instance computes a [BLOCK_BATCH, BLOCK_N, BLOCK_H, BLOCK_W] block of y
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of y it should compute.
    pid_nhw = tl.program_id(0)
    pid_k = tl.program_id(1)

    # offset for output y
    off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
    off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
    off_y_n = off_y_nhw // (OUT_H * OUT_W)
    off_y_hw = off_y_nhw % (OUT_H * OUT_W)
    off_y_h = off_y_hw // OUT_W + output_padding_h
    off_y_w = off_y_hw % OUT_W + output_padding_w

    # offset for the initial ptr for x
    off_x_n = off_y_n
    off_x_h = off_y_h * stride_h - padding_h
    off_x_w = off_y_w * stride_w - padding_w
    off_x_nhw = off_x_n * stride_xn + off_x_h * stride_xh + off_x_w * stride_xw
    off_x_crs = tl.arange(0, BLOCK_K)

    CRS = IN_C * KERNEL_H * KERNEL_W
    # load inc ptr of x, update x_ptrs
    if not CONV1X1_NHWC:
        delta_xh_ptrs = delta_xh_ptr + off_x_crs
        delta_xw_ptrs = delta_xw_ptr + off_x_crs
        delta_xc_ptrs = delta_xc_ptr + off_x_crs
        delta_xh = tl.load(delta_xh_ptrs, mask=off_x_crs < CRS, other=0)
        delta_xw = tl.load(delta_xw_ptrs, mask=off_x_crs < CRS, other=0)
        delta_xc = tl.load(delta_xc_ptrs, mask=off_x_crs < CRS, other=0)
        off_x_crs_unpacked = (
            delta_xh * stride_xh + delta_xw * stride_xw + delta_xc * stride_xc
        )
        x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
    else:
        x_ptrs = x + off_x_nhw[:, None] + off_x_crs[None, :]
        delta_xh = 0
        delta_xw = 0

    mask_x = (
        (off_x_n < BATCH)[:, None]
        & (off_x_crs < CRS)[None, :]
        & (off_x_h[:, None] + delta_xh[None, :] >= 0)
        & (off_x_h[:, None] + delta_xh[None, :] < IN_H)
        & (off_x_w[:, None] + delta_xw[None, :] >= 0)
        & (off_x_w[:, None] + delta_xw[None, :] < IN_W)
    )

    # offset for the initial ptr for w
    off_w_crs = tl.arange(0, BLOCK_K)
    off_w_k = off_y_k
    w_ptrs = w + off_w_crs[:, None] + off_w_k[None, :] * stride_wn
    mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]

    # ------ load x ------
    matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
    # ------ load w ------
    matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)

    # -----------------------------------------------------------
    # allocate accumulator
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for crs in range(0, CRS, BLOCK_K):

        # ------ matrix multiplication ------
        acc += tl.dot(matrix_x, matrix_w)
        # ------ update ptrs ------
        w_ptrs += BLOCK_K
        # load inc ptr of x, update x_ptrs
        off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K)
        if not CONV1X1_NHWC:
            delta_xh_ptrs += BLOCK_K
            delta_xw_ptrs += BLOCK_K
            delta_xc_ptrs += BLOCK_K
            delta_xh = tl.load(delta_xh_ptrs, mask=off_x_crs < CRS, other=0)
            delta_xw = tl.load(delta_xw_ptrs, mask=off_x_crs < CRS, other=0)
            delta_xc = tl.load(delta_xc_ptrs, mask=off_x_crs < CRS, other=0)
            off_x_crs_unpacked = (
                delta_xh * stride_xh + delta_xw * stride_xw + delta_xc * stride_xc
            )
            x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
        else:
            x_ptrs += BLOCK_K

        mask_x = (
            (off_x_n < BATCH)[:, None]
            & (off_x_crs < CRS)[None, :]
            & (off_x_h[:, None] + delta_xh[None, :] >= 0)
            & (off_x_h[:, None] + delta_xh[None, :] < IN_H)
            & (off_x_w[:, None] + delta_xw[None, :] >= 0)
            & (off_x_w[:, None] + delta_xw[None, :] < IN_W)
        )
        mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]
        # ------ prefetch ------
        # ------ load x ------
        matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
        # ------ load w ------
        matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)

    acc = acc.to({{out_def}}.dtype.element_ty)

    {% if keep_store %}
    # rematerialize -- this saves some registers
    # offset for output y
    off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
    off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
    off_y_n = off_y_nhw // (OUT_H * OUT_W)
    off_y_hw = off_y_nhw % (OUT_H * OUT_W)
    # consider output padding
    off_y_h = off_y_hw // OUT_W + output_padding_h
    off_y_w = off_y_hw % OUT_W + output_padding_w

    # y ptrs in the block of [BLOCK_M, BLOCK_N]
    y_ptrs = (
        {{out_def}}
        + off_y_n[:, None] * stride_yn
        + off_y_h[:, None] * stride_yh
        + off_y_w[:, None] * stride_yw
        + off_y_k[None, :] * stride_yc
    )

    # out-of-bounds check
    mask_y = (
        (off_y_n < BATCH)[:, None]
        & (off_y_h < OUT_H + output_padding_h)[:, None]
        & (off_y_w < OUT_W + output_padding_w)[:, None]
        & (off_y_k < KERNEL_N)[None, :]
    )
    tl.store(y_ptrs, acc, mask=mask_y)
    {% endif %}

    {% if pointwise_code %}
    {{ pointwise_code | indent(4, true) }}
    {#
    z = tl.load(z_ptrs, mask=mask_z)
    acc += z
    #}
    {% endif %}

    return
@@ -1,80 +0,0 @@
import torch
import triton
import triton.language as tl
{# from triton.ops.matmul import get_configs_io_bound #}

@mm_autotune()
@mm_heuristics()
@triton.jit
def {{kernel_name}}(
    {% for i in template_inout_argdefs %}
    {{i}},
    {% endfor %}
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    # fusable kernels args
    {% for i in extra_argdefs %}
    {{i}},
    {% endfor %}
    allow_tf32: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    ACC_TYPE: tl.constexpr,
):
    # matrix multiplication
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers
    A_ptrs = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B_ptrs = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K * SPLIT_K):
        if EVEN_K:
            a = tl.load(A_ptrs)
            b = tl.load(B_ptrs)
        else:
            a = tl.load(A_ptrs, mask=rk[None, :] < k, other=0.0)
            b = tl.load(B_ptrs, mask=rk[:, None] < k, other=0.0)
        acc += tl.dot(a, b, allow_tf32=allow_tf32)
        A_ptrs += BLOCK_K * SPLIT_K * stride_ak
        B_ptrs += BLOCK_K * SPLIT_K * stride_bk
    acc = acc.to({{out_def}}.dtype.element_ty)

    {% if keep_store %}
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C_ptrs = {{out_def}} + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    tl.store(C_ptrs, acc, mask=mask)
    {% endif %}

    {% if pointwise_code %}
    {{ pointwise_code | indent(4, true) }}
    {% endif %}
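[Editor's note] The "re-order program ID" block in the deleted mm template above is the standard grouped-launch swizzle. Running the same arithmetic on the host (editorial sketch) shows consecutive pids walking down GROUP_M rows of tiles before advancing a column, which keeps tiles that reuse the same B columns adjacent in launch order for better L2 hit rates:

    GROUP_M, grid_m, grid_n = 2, 4, 4
    width = GROUP_M * grid_n
    for pid in range(grid_m * grid_n):
        group_id = pid // width
        group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
        pid_m = group_id * GROUP_M + (pid % group_size)
        pid_n = (pid % width) // group_size
        print(pid, "->", (pid_m, pid_n))   # 0->(0,0), 1->(1,0), 2->(0,1), ...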
@@ -1,351 +0,0 @@
import logging
import os

import sympy

from .. import config, ir
from ..virtualized import V
from .common import IndentedBuffer
from .triton import TritonKernel

log = logging.getLogger(__name__)
template_dict = {ir.Convolution: "triton_conv", ir.MatrixMultiply: "triton_mm"}


class TritonTemplateKernel(TritonKernel):
    def __init__(self, node: ir.ExternKernel, *groups):
        from jinja2 import Environment, FileSystemLoader, StrictUndefined

        self.node = node
        self.template_name = template_dict[type(node)]
        env = Environment(
            loader=FileSystemLoader(os.path.dirname(__file__)),
            trim_blocks=True,
            lstrip_blocks=True,
            undefined=StrictUndefined,
        )
        pid_cache = {}
        if isinstance(node, ir.Convolution):
            pid_cache = {
                "tl.program_id(0)": "pid_nhw",
                "tl.program_id(1)": "pid_k",
            }
            self.map_args()
            KERNEL_H = self.args_dict["KERNEL_H"]
            KERNEL_W = self.args_dict["KERNEL_W"]
            padding_h = self.args_dict["padding_h"]
            padding_w = self.args_dict["padding_w"]
            if ((KERNEL_H == "1" and KERNEL_W == "1")) or (
                (padding_h == "0") and (padding_w == "0")
            ):
                self.template_name += "_delta_x"
            else:
                self.template_name += "_delta_x_hwc"
        elif isinstance(node, ir.MatrixMultiply):
            pid_cache = {
                "tl.program_id(0)": "pid_m",
                "tl.program_id(1)": "pid_n",
            }

        self.template = env.get_template(self.template_name + ".j2")
        super(TritonTemplateKernel, self).__init__(*groups, pid_cache=pid_cache)

    def rename_vars(self):
        for k, v in self.inout_dict.items():
            self.args.output_buffers[v] = k
        if isinstance(self.node, ir.Convolution):
            self.cse.store_cache[self.inout_dict["y"]] = "acc"
        elif isinstance(self.node, ir.MatrixMultiply):
            self.cse.store_cache[self.inout_dict["C"]] = "acc"

    def assign_block_numel(self):
        code = IndentedBuffer()
        if isinstance(self.node, ir.Convolution):
            code.writeline("XBLOCK: tl.constexpr = BLOCK_M")
            code.writeline("YBLOCK: tl.constexpr = BLOCK_N")
            code.writeline(
                "xnumel = BATCH * (OUT_H + 2 * output_padding_h) * (OUT_W + 2 * output_padding_w)"
            )
            code.writeline("ynumel = KERNEL_N")
        elif isinstance(self.node, ir.MatrixMultiply):
            code.writeline("XBLOCK: tl.constexpr = BLOCK_M")
            code.writeline("YBLOCK: tl.constexpr = BLOCK_N")
            code.writeline("xnumel = M")
            code.writeline("ynumel = N")

        return code

    def indexing(self, index: sympy.Expr, copy_shape=None, dense_indexing=True):
        # use dense_indexing for TritonTemplateKernel to avoid map::at error
        return super().indexing(
            index, copy_shape=copy_shape, dense_indexing=dense_indexing
        )

    def codegen_body(
        self, name, fuse, could_remove_kernel_buf, kernel_buf_replace_name
    ):
        """
        put render_variables into the template
        to generate the final code
        """
        # get extra_argdefs from fusable triton kernels
        self.extra_argdefs = []
        self.extra_call_args = []
        argdefs, call_args, _ = self.args.python_argdefs()
        # add extra args if they differ from the
        # current TritonTemplateKernel args
        for (argdef, call_arg) in zip(argdefs, call_args):
            if (
                argdef not in self.inout_dict.keys()
                and argdef not in self.args_dict.keys()
            ):
                self.extra_argdefs.append(argdef)
                self.extra_call_args.append(call_arg)

        if could_remove_kernel_buf:
            if isinstance(self.node, ir.Convolution):
                self.inout_dict.pop("y")
            elif isinstance(self.node, ir.MatrixMultiply):
                self.inout_dict.pop("C")
        self.template_inout_argdefs = list(self.inout_dict.keys())

        if kernel_buf_replace_name is not None:
            idx = self.extra_call_args.index(kernel_buf_replace_name)
            kernel_buf_replace_def = self.extra_argdefs[idx]

        super().codegen_body()
        self.pointwise_code = IndentedBuffer()
        self.pointwise_code.splice(self.assign_block_numel())
        self.pointwise_code.splice(self.body)
        render_dict = {}
        render_dict["kernel_name"] = name
        render_dict["template_inout_argdefs"] = self.template_inout_argdefs
        render_dict["extra_argdefs"] = self.extra_argdefs
        render_dict["pointwise_code"] = self.pointwise_code.getvalue() if fuse else None
        render_dict["keep_store"] = not could_remove_kernel_buf
        render_dict["out_def"] = (
            self.out_def() if not could_remove_kernel_buf else kernel_buf_replace_def
        )
        self.body = self.template.render(render_dict) + "\n"

    def out_def(self):
        if isinstance(self.node, ir.Convolution):
            return "y"
        elif isinstance(self.node, ir.MatrixMultiply):
            return "C"

    def codegen_kernel(
        self,
        name=None,
        fuse=False,
        could_remove_kernel_buf=False,
        kernel_buf_replace_name=None,
    ):

        code = IndentedBuffer()

        self.codegen_body(name, fuse, could_remove_kernel_buf, kernel_buf_replace_name)
        code.splice(self.body)

        if name is not None:
            return code.getvalue()

        wrapper = IndentedBuffer()
        wrapper.writeline("TritonCodeCache.load('''")
        wrapper.splice(code.getvalue(), strip=True)
        wrapper.writeline("''').kernel")

        return wrapper.getvalue()

    def map_args(self):
        """
        map the constant args of
        kernel[grid](..., IN_C, IN_H, IN_W, strides, ...)
        """
        (
            self.inout_dict,
            self.args_dict,
            self.const_dict,
            self.other_dict,
        ) = self.node.map_args()

    def precompute(self, wrapper, kernel_name):
        """
        some triton kernels need a host-precomputed tensor;
        for example, triton_conv needs a precomputed delta_x_ptr
        """
        if isinstance(self.node, ir.Convolution):
            if self.const_dict["CONV1X1_NHWC"] == "False":
                IN_C = self.args_dict["IN_C"]
                KERNEL_H = self.args_dict["KERNEL_H"]
                KERNEL_W = self.args_dict["KERNEL_W"]
                dilation_h = self.args_dict["dilation_h"]
                dilation_w = self.args_dict["dilation_w"]
                stride_wc = self.args_dict["stride_wc"]
                stride_wh = self.args_dict["stride_wh"]
                stride_ww = self.args_dict["stride_ww"]
                stride_xc = self.args_dict["stride_xc"]
                stride_xh = self.args_dict["stride_xh"]
                stride_xw = self.args_dict["stride_xw"]
                device = self.other_dict["device"]
                if self.template_name == "triton_conv_delta_x":
                    assert "delta_x_ptr" not in self.args_dict.keys()
                    self.args_dict["delta_x_ptr"] = f"delta_x_{kernel_name}"
                    wrapper.writeline(
                        f"from {config.inductor_import}.triton_ops import _conv"
                    )
                    wrapper.writeline(
                        f"delta_x_{kernel_name} = _conv._delta_x_ptr("
                        f"{IN_C}, {KERNEL_H}, {KERNEL_W}, "
                        f"{dilation_h}, {dilation_w}, "
                        f"{stride_wc}, {stride_wh}, {stride_ww}, "
                        f"{stride_xc}, {stride_xh}, {stride_xw}, {device})"
                    )
                # triton_conv_delta_x_hwc
                else:
                    assert "delta_xh_ptr" not in self.args_dict.keys()
                    assert "delta_xw_ptr" not in self.args_dict.keys()
                    assert "delta_xc_ptr" not in self.args_dict.keys()
                    self.args_dict["delta_xh_ptr"] = f"delta_xh_{kernel_name}"
                    self.args_dict["delta_xw_ptr"] = f"delta_xw_{kernel_name}"
                    self.args_dict["delta_xc_ptr"] = f"delta_xc_{kernel_name}"
                    wrapper.writeline(
                        f"from {config.inductor_import}.triton_ops import _conv"
                    )
                    wrapper.writeline(
                        f"delta_xh_{kernel_name}, delta_xw_{kernel_name}, delta_xc_{kernel_name}"
                        f" = _conv._delta_x_ptr_hwc("
                        f"{IN_C}, {KERNEL_H}, {KERNEL_W}, "
                        f"{dilation_h}, {dilation_w}, "
                        f"{stride_wc}, {stride_wh}, {stride_ww}, "
                        f"{stride_xc}, {stride_xh}, {stride_xw}, {device})"
                    )

            # else, delta_x_ptr is None
            else:
                assert "delta_x_ptr" not in self.args_dict.keys()
                self.args_dict["delta_x_ptr"] = "None"
        return

    def gen_grid(self, name):
        code = IndentedBuffer()
        if isinstance(self.node, ir.Convolution):
            BATCH = self.args_dict["BATCH"]
            OUT_H = self.args_dict["OUT_H"]
            OUT_W = self.args_dict["OUT_W"]
            KERNEL_N = self.args_dict["KERNEL_N"]
            code.splice(
                f"""
                def grid_{name}(META):
                    return (
                        triton.cdiv({BATCH} * {OUT_H} * {OUT_W}, META["BLOCK_M"]),
                        triton.cdiv({KERNEL_N}, META["BLOCK_N"]),
                    )
                """
            )
        if isinstance(self.node, ir.MatrixMultiply):
            M = self.args_dict["M"]
            N = self.args_dict["N"]
            code.splice(
                f"""
                def grid_{name}(META):
                    return (
                        triton.cdiv({M}, META["BLOCK_M"]) * triton.cdiv({N}, META["BLOCK_N"]),
                        META["SPLIT_K"],
                    )
                """
            )
        return code.getvalue()

    def call_kernel(self, wrapper, name: str):
        # gen code to call kernel
        # e.g.
        # def grid(META):
        #     return (...)
        # kernel1[grid](arg0, arg1, ...)
        extra_args = ", ".join(self.extra_call_args)
        self_args = ", ".join({**self.inout_dict, **self.args_dict}.values())
        self_const_kwargs = ", ".join(f"{k}={v}" for k, v in self.const_dict.items())
        args = self_args + (
            ", " + extra_args if extra_args and len(extra_args) > 0 else ""
        )
        args_kwargs = args + ", " + self_const_kwargs
        lines = self.gen_grid(name).split("\n")
        for l in lines:
            wrapper.writeline(l)
        wrapper.writeline(f"{name}[grid_{name}]({args_kwargs})")


def should_use_template(node: ir.ExternKernel):
    template_kernels = [ir.Convolution, ir.MatrixMultiply]
    if type(node) in template_kernels and ir.is_triton(node.get_device()):
        if isinstance(node, ir.Convolution):
            return node.kernel != "aten.convolution"
        elif isinstance(node, ir.MatrixMultiply):
            return node.kernel != "aten.mm.out"
    return False


def template_can_fuse(snode1, snode2):
    assert snode1.is_template()
    if snode1.group != snode2.group:
        return False
    tiling = snode1.get_nodes()[0].node.get_template_tiling()
    for node in snode2.get_nodes():
        if not TritonKernel.is_compatible(tiling, node.get_ranges()):
            return False
    return True


def template_codegen(scheduler, scheduler_node, epilogue):
    """
    codegen function for triton templates
    scheduler: Scheduler
    scheduler_node: ExternKernelSchedulerNode
    """
    log.debug("template_codegen: %s -- %s", scheduler_node, epilogue)

    wrapper = V.graph.wrapper_code
    _, groups = scheduler_node.group

    with TritonTemplateKernel(
        scheduler_node.node, *scheduler_node.node.get_template_tiling()
    ) as kernel:
        # map const args / shapes / strides to kernel args
        kernel.map_args()
        # set self.args names to match the TritonTemplateKernel's arg names
        kernel.rename_vars()
        # scheduler.pop_group will keep iterating all reachable fusable SchedulerNodes
        assert type(kernel.node) in template_dict.keys()

        kernel.store_buffer_names.add(scheduler_node.get_name())

        for node in epilogue:
            node.mark_run()
            node.codegen(kernel.split_and_set_ranges(node.get_ranges()))

        could_remove_kernel_buf = (
            kernel.args.output_buffers[scheduler_node.get_name()] == "REMOVED"
        )
        kernel_buf_replace_name = None
        if could_remove_kernel_buf:
            for node in epilogue:
                if not kernel.args.is_removed(node.get_name()):
                    kernel_buf_replace_name = node.get_name()
                    break
            assert kernel_buf_replace_name is not None

    kernel_name = "triton_template_" + wrapper.next_kernel_suffix()
    # code gen kernel
    wrapper.header.splice(
        kernel.codegen_kernel(
            kernel_name,
            bool(epilogue),
            could_remove_kernel_buf,
            kernel_buf_replace_name,
        )
    )
    # gen precompute tensor (like delta_x_ptr) if needed
    kernel.precompute(wrapper, kernel_name)
    # code gen call to kernel
    kernel.call_kernel(wrapper, kernel_name)
@@ -272,6 +272,7 @@ class WrapperCodeGen(CodeGen):
                import random
                from torch import empty_strided, as_strided, device
                from {codecache.__name__} import AsyncCompile
+                from torch._inductor.select_algorithm import extern_kernels

                aten = torch.ops.aten
                assert_size_stride = torch._C._dynamo.guards.assert_size_stride

@@ -299,19 +300,6 @@ class WrapperCodeGen(CodeGen):
            """
        )

-        if config.triton.mm != "aten":
-            self.header.splice(
-                f"""
-                from {config.inductor_import}.triton_ops.autotune import mm_heuristics
-                from {config.inductor_import}.triton_ops.autotune import mm_autotune
-                """
-            )
-
-        if config.triton.use_bmm:
-            self.header.writeline(
-                f"from {config.inductor_import}.triton_ops.batched_matmul import bmm_out as triton_bmm_out"
-            )
-
        self.write_prefix()

        for name, value in V.graph.constants.items():

@@ -325,6 +313,21 @@ class WrapperCodeGen(CodeGen):
            self.write_get_cuda_stream
        )

+        @functools.lru_cache(None)
+        def add_import_once(line):
+            self.header.writeline(line)
+
+        self.add_import_once = add_import_once
+        self._metas = {}
+
+    def add_meta_once(self, meta):
+        meta = repr(meta)
+        if meta not in self._metas:
+            var = f"meta{len(self._metas)}"
+            self._metas[meta] = var
+            self.header.writeline(f"{var} = {meta}")
+        return self._metas[meta]
+
    @cache_on_self
    def get_output_refs(self):
        return [x.codegen_reference() for x in V.graph.graph_outputs]
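[Editor's note] A sketch of the new add_meta_once dedup behavior (editorial; a real WrapperCodeGen is built during graph lowering, so this standalone instantiation is hypothetical): equal repr()s collapse to one header line and one shared name:

    wrapper = WrapperCodeGen()                         # hypothetical standalone use
    m = {"BLOCK_M": 64, "BLOCK_N": 64}
    assert wrapper.add_meta_once(m) == "meta0"         # writes `meta0 = {...}` once
    assert wrapper.add_meta_once(dict(m)) == "meta0"   # same repr: name is reused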
@@ -36,6 +36,15 @@ inplace_buffers = True
# codegen benchmark harness
benchmark_harness = True

+# fuse pointwise into templates
+epilogue_fusion = False
+
+# do epilogue fusions before other fusions
+epilogue_fusion_first = False
+
+# enable slow autotuning passes to select algorithms
+max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"
+
# control store vs recompute heuristic
# For fanouts, rematerialization can lead to exponential blowup, so use a
# smaller threshold
@ -143,12 +152,9 @@ class triton:
|
||||
# Synchronize after every kernel launch, to help pinpoint bugs
|
||||
debug_sync_kernel = False
|
||||
|
||||
# choose conv backend, "aten" or "triton" or "autotune"
|
||||
# choose conv backend, "aten" or "triton"
|
||||
convolution = "aten"
|
||||
|
||||
# choose mm backend, "aten" or "triton" or "autotune"
|
||||
mm = "aten"
|
||||
|
||||
# Always load full blocks (rather than broadcasting inside the block)
|
||||
# Set default as True because otherwise will encouter `map::at` error
|
||||
# in triton if loading from 1-dim tensor using 2-dim pointer offset
|
||||
@ -159,10 +165,9 @@ class triton:
|
||||
# limit tiling dimensions
|
||||
max_tiles = 2
|
||||
|
||||
# use triton.autotune?
|
||||
autotune = True
|
||||
|
||||
use_bmm = False
|
||||
# use triton.autotune for pointwise ops with complex layouts
|
||||
# this should only be disabled for debugging/testing
|
||||
autotune_pointwise = True
|
||||
|
||||
# should we stop a fusion to allow better tiling?
|
||||
tiling_prevents_pointwise_fusion = True
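
For reference, these knobs can be flipped from user code before compiling. A hedged sketch (config names are taken from this diff; defaults may differ across versions):

import torch._inductor.config as inductor_config

inductor_config.epilogue_fusion = True           # fuse pointwise ops into templates
inductor_config.epilogue_fusion_first = False    # try epilogue fusions before others
inductor_config.max_autotune = True              # benchmark ATen vs Triton template choices
inductor_config.triton.autotune_pointwise = True # autotune pointwise kernels with complex layouts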
@ -209,33 +214,25 @@ class trace:

class InductorConfigContext:
    static_memory: bool
    matmul_tune: str
    matmul_padding: bool
    triton_autotune: bool
    triton_bmm: bool
    triton_mm: str
    max_autotune: bool
    triton_convolution: str
    rematerialize_threshold: int
    rematerialize_acc_threshold: int

    def _save(self):
        self.static_memory = triton.cudagraphs
        self.matmul_tune = triton.mm
        self.matmul_padding = shape_padding
        self.triton_autotune = triton.autotune
        self.triton_bmm = triton.use_bmm
        self.triton_mm = triton.mm
        self.max_autotune = max_autotune
        self.triton_convolution = triton.convolution
        self.rematerialize_threshold = realize_reads_threshold
        self.rematerialize_acc_threshold = realize_acc_reads_threshold

    def _apply(self):
        global shape_padding, realize_reads_threshold, realize_acc_reads_threshold, max_autotune
        triton.cudagraphs = self.static_memory
        triton.mm = self.matmul_tune
        shape_padding = self.matmul_padding
        triton.autotune = self.triton_autotune
        triton.use_bmm = self.triton_bmm
        triton.mm = self.triton_mm
        max_autotune = self.max_autotune
        triton.convolution = self.triton_convolution
        realize_reads_threshold = self.rematerialize_threshold
        realize_acc_reads_threshold = self.rematerialize_acc_threshold
@ -254,11 +251,7 @@ class InductorConfigContext:
            self.static_memory = True

        def max_autotune():
            self.static_memory = False
            self.matmul_padding = True
            self.triton_convolution = "autotune"
            self.triton_mm = "autotune"
            self.matmul_padding = True
            self.max_autotune = True

        modes = {
            x.__name__.replace("_", "-"): x

@ -27,12 +27,10 @@ from torch.fx.passes.tools_common import legalize_graph
from . import config, ir  # noqa: F811, this is needed
from .scheduler import (
    BaseSchedulerNode,
    ExternKernelSchedulerNode,
    FusedSchedulerNode,
    NopKernelSchedulerNode,
    OutputNode,
    SchedulerNode,
    TemplateSchedulerNode,
)
from .utils import dynamo_config, dynamo_debug_utils, dynamo_utils
from .virtualized import V
@ -110,10 +108,10 @@ def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
    group: Any = None
    # create call_function node for each Buffer and Kernel
    for snode in snodes:
        if isinstance(snode, ExternKernelSchedulerNode):
        if snode.is_extern():
            node_type = "extern"
            group = node_type
        elif isinstance(snode, TemplateSchedulerNode):
        elif snode.is_template():
            node_type = "template"
            group = node_type
        elif isinstance(snode, NopKernelSchedulerNode):

@ -165,14 +165,6 @@ def pad_dim(x, padded_length, dim):

@register_decomposition([aten.addmm])
def addmm(input, mat1, mat2, *, beta=1, alpha=1):
    if config.triton.mm != "aten":
        out = torch.mm(mat1, mat2)
        if not isinstance(alpha, numbers.Number) or alpha != 1:
            out = out * alpha
        if not isinstance(beta, numbers.Number) or beta != 1:
            input = input * beta
        return input + out

    if (
        config.shape_padding
        and check_device(mat1, mat2)

@ -120,6 +120,7 @@ class GraphLowering(torch.fx.Interpreter):
        self.name = "GraphLowering"
        self._can_use_cpp_wrapper = config.cpp_wrapper
        self.graph_id = graph_id
        self.scheduler = None

    def get_dtype(self, buffer_name: str):
        if buffer_name in self.constants:
@ -175,10 +176,7 @@ class GraphLowering(torch.fx.Interpreter):

    def check_buffer_for_cpp_wrapper(self, buffer: ir.ComputedBuffer):
        if isinstance(buffer, ir.ExternKernel):
            if not isinstance(
                buffer,
                (ir.MatrixMultiply, ir.BatchMatrixMultiply, ir.MatrixMultiplyAdd),
            ):
            if not getattr(buffer, "cpp_kernel", False):
                self.disable_cpp_wrapper("ExternKernel")

    def register_buffer(self, buffer: ir.ComputedBuffer):
@ -296,6 +294,7 @@ class GraphLowering(torch.fx.Interpreter):
            out = lowerings[target](*args, **kwargs)
            return out
        except Exception as e:
            log.exception("Error from lowering")
            raise LoweringException(e, target, args, kwargs) from e

    def get_attr(self, target, args, kwargs):
@ -471,6 +470,7 @@ class GraphLowering(torch.fx.Interpreter):
        self.init_wrapper_code()

        self.scheduler = Scheduler(self.buffers)
        assert self.scheduler is not None  # mypy can't figure this out
        self.scheduler.codegen()
        assert self.wrapper_code is not None
        return self.wrapper_code.generate()
@ -1578,6 +1578,9 @@ class Constant(BaseConstant):

        return loader

    def realize(self):
        pass


@dataclasses.dataclass
class IndexingConstant(BaseConstant):
@ -1698,6 +1701,24 @@ class Layout(IRNode):
class FixedLayout(Layout):
    """A Tensor layout we cannot change"""

    def __init__(
        self,
        device: torch.device,
        dtype: torch.dtype,
        size: List[Expr],
        stride: List[Expr] = None,
        offset: Expr = Integer(0),
    ):
        if stride is None:
            stride = FlexibleLayout.contiguous_strides(size)
        super().__init__(
            device,
            dtype,
            size,
            stride,
            offset,
        )

    def make_indexer(self):
        """A closure containing math to read a given element"""

@ -2250,6 +2271,58 @@ class ComputedBuffer(Buffer):
        return self.data.constant_to_device(device)


class TemplateBuffer(Buffer):
    """
    Represents a Triton (and, in the future, other types of) template operator
    that we can fuse an epilogue onto.
    """

    def __init__(self, layout, inputs, make_kernel_render):
        super().__init__(name=None, layout=layout)
        self.inputs = InputsKernel.unwrap_storage(inputs)
        self.make_kernel_render = make_kernel_render
        self.name = V.graph.register_buffer(self)

    def get_read_writes(self):
        return self.normalized_read_writes()

    @cache_on_self
    def normalized_read_writes(self):
        name = self.get_name()
        indexer = self.layout.make_indexer()

        def dummy(index, rindex):
            assert len(rindex) == 0
            return ops.store(name, indexer(index), "fake")

        deps = dependencies.extract_read_writes(
            dummy, self.get_size(), (), normalize=True
        )
        deps.reads = {dependencies.StarDep(x.get_name()) for x in self.inputs}
        return deps

    def get_reduction_size(self):
        return 1

    def get_reduction_type(self):
        return None

    def is_no_op(self):
        return False

    def should_allocate(self):
        return True

    def simplify_and_reorder(self):
        return (
            (
                self.get_size(),
                (),
            ),
            None,
        )


@dataclasses.dataclass
class InputsKernel(Buffer):
    inputs: List[Buffer]
@ -2688,12 +2761,24 @@ class ExternKernelOut(ExternKernel):
            self.cpp_kernel,
        )

    def __init__(self, layout, inputs, constant_args=(), kwargs=None, output_view=None):
    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        kwargs=None,
        output_view=None,
        kernel=None,
        cpp_kernel=None,
    ):
        super().__init__(
            None, layout, self.unwrap_storage(inputs), constant_args, kwargs or {}
        )
        self.output_view = output_view
        self.name = V.graph.register_buffer(self)
        if kernel is not None:
            self.kernel = kernel
        self.cpp_kernel = cpp_kernel

    def should_allocate(self):
        return True
@ -2784,198 +2869,6 @@ class IndexPutFallback(ExternKernel):
        self.name = V.graph.register_buffer(self)


class MatrixMultiply(ExternKernelOut):
    kernel = "aten.mm.out"
    cpp_kernel = "at::mm_out"

    def __init__(
        self, layout, inputs, constant_args=(), output_view=None, kernel="aten.mm.out"
    ):
        super().__init__(layout, inputs, constant_args, output_view)
        self.kernel = kernel

    @classmethod
    def create(cls, a, b):
        *m, k1 = a.get_size()
        k2, n = b.get_size()
        V.graph.sizevars.guard_equals(k1, k2)
        a = cls.realize_input(a)
        b = cls.realize_input(b)
        if len(m) != 1 and not a.get_layout().is_contiguous():
            a = cls.copy_input(a)
        else:
            a = cls.require_stride1(a)
        b = cls.require_stride1(b)

        # choose runtime kernel
        config_mm = config.triton.mm
        # default kernel is aten
        kernel = "aten.mm.out"
        if config_mm == "aten":
            kernel = "aten.mm.out"
        elif config_mm == "triton" and a.get_device().type == "cuda":
            kernel = "triton_ops.matmul_out"
        elif config_mm == "autotune":
            from .codegen.autotuner import tuned_mm

            kernel = tuned_mm(
                a.get_size(),
                b.get_size(),
                a.get_stride(),
                b.get_stride(),
                a.get_device(),
                a.get_dtype(),
            )

        return MatrixMultiply(
            layout=FlexibleLayout(
                device=a.get_device(),
                dtype=a.get_dtype(),
                size=list(m) + [n],
            ),
            inputs=[a, b],
            kernel=kernel,
        )

    def get_template_tiling(self):
        tile1, tile2 = self.get_size()
        return (
            tile1,
            tile2,
            sympy.Integer(1),
        )

    def map_args(self):
        # a, b
        in_args = [x.codegen_reference() for x in self.inputs]
        # const_args = self.constant_args
        inout_dict = OrderedDict(
            [
                ("A", f"{in_args[0]}"),
                ("B", f"{in_args[1]}"),
                ("C", f"{self.get_name()}"),
            ]
        )
        # batch==1 bmm->mm
        if len(self.get_stride()) == 3:
            assert self.get_size()[0] == 1
            stride_cm = self.get_stride()[1]
            stride_cn = self.get_stride()[2]
        else:
            stride_cm = self.get_stride()[0]
            stride_cn = self.get_stride()[1]
        args_dict = OrderedDict(
            [
                ("M", f"{self.inputs[0].get_size()[0]}"),
                ("N", f"{self.inputs[1].get_size()[1]}"),
                ("K", f"{self.inputs[0].get_size()[1]}"),
                ("stride_am", f"{self.inputs[0].get_stride()[0]}"),
                ("stride_ak", f"{self.inputs[0].get_stride()[1]}"),
                ("stride_bk", f"{self.inputs[1].get_stride()[0]}"),
                ("stride_bn", f"{self.inputs[1].get_stride()[1]}"),
                ("stride_cm", f"{stride_cm}"),
                ("stride_cn", f"{stride_cn}"),
            ]
        )
        # accumulator types
        ACC_TYPE = (
            "tl.float32"
            if self.inputs[0].get_dtype()
            in [torch.float16, torch.bfloat16, torch.float32]
            else "tl.int32"
        )
        # dict for tl.constexpr
        const_dict = OrderedDict(
            [
                ("GROUP_M", "8"),
                ("ACC_TYPE", ACC_TYPE),
                ("allow_tf32", f"{torch.backends.cuda.matmul.allow_tf32}"),
            ]
        )

        other_dict = OrderedDict()

        return inout_dict, args_dict, const_dict, other_dict


class MatrixMultiplyAdd(ExternKernelOut):
    def __init__(self, layout, inputs, constant_args=(), kwargs=None, output_view=None):
        super().__init__(layout, inputs, constant_args, kwargs or {}, output_view)
        self.kernel = "aten.addmm.out"
        self.cpp_kernel = "at::addmm_out"
        self.ordered_kwargs_for_cpp_kernel = ["beta", "alpha"]

    @classmethod
    def create(cls, inp, a, b, beta, alpha):
        m, k1 = a.get_size()
        k2, n = b.get_size()
        V.graph.sizevars.guard_equals(k1, k2)
        inp = cls.realize_input(inp)
        a = cls.realize_input(a)
        b = cls.realize_input(b)
        a = cls.require_stride1(a)
        b = cls.require_stride1(b)
        return MatrixMultiplyAdd(
            layout=FlexibleLayout(
                device=a.get_device(),
                dtype=a.get_dtype(),
                size=[m] + [n],
            ),
            inputs=[inp, a, b],
            kwargs={"beta": beta, "alpha": alpha},
        )


class BatchMatrixMultiply(ExternKernelOut):
    kernel = "aten.bmm.out"
    cpp_kernel = "at::bmm_out"

    def __init__(self, layout, inputs, constant_args=(), output_view=None):
        super().__init__(layout, inputs, constant_args, output_view)
        if (
            config.triton.use_bmm
            and len(inputs) > 0
            and inputs[0].get_device().type == "cuda"
        ):
            self.kernel = "triton_bmm_out"

    @classmethod
    def create(cls, a, b):
        b1, m, k1 = a.get_size()
        b2, k2, n = b.get_size()
        b3 = V.graph.sizevars.guard_equals(b1, b2)
        V.graph.sizevars.guard_equals(k1, k2)
        a = cls.require_stride1(cls.realize_input(a))
        b = cls.require_stride1(cls.realize_input(b))

        output_layout = FlexibleLayout(
            device=a.get_device(),
            dtype=a.get_dtype(),
            size=[b3, m, n],
        ).as_fixed()

        if b3 == 1:
            # convert to normal mm
            data = MatrixMultiply(
                layout=output_layout.as_fixed(),
                inputs=[SqueezeView.create(a, dim=0), SqueezeView.create(b, dim=0)],
            )
            data.output_view = ReinterpretView(
                data,
                FlexibleLayout(
                    device=a.get_device(),
                    dtype=a.get_dtype(),
                    size=[m, n],
                ).as_fixed(),
            )
        else:
            data = BatchMatrixMultiply(
                layout=output_layout,
                inputs=[a, b],
            )
        return data


class DeviceCopy(ExternKernelOut):
    @classmethod
    def create(cls, x, device):
@ -3862,7 +3755,14 @@ class StorageBox(MutableBox):

    def realize(self):
        if isinstance(
            self.data, (ComputedBuffer, InputsKernel, InputBuffer, ReinterpretView)
            self.data,
            (
                ComputedBuffer,
                InputsKernel,
                InputBuffer,
                ReinterpretView,
                TemplateBuffer,
            ),
        ):
            return self.data.get_name()
        assert isinstance(self.data, (Pointwise, Reduction)), type(self.data)
0
torch/_inductor/kernel/__init__.py
Normal file
126
torch/_inductor/kernel/bmm.py
Normal file
@ -0,0 +1,126 @@
import torch

from ..lowering import register_lowering
from ..select_algorithm import (
    autotune_select_algorithm,
    ExternKernelChoice,
    TritonTemplate,
)
from ..utils import ceildiv as cdiv, use_triton_template

from .mm_common import addmm_epilogue, mm_args, mm_configs, mm_options

aten = torch.ops.aten


def bmm_grid(b, m, n, meta):
    return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1)


bmm_template = TritonTemplate(
    name="bmm",
    grid=bmm_grid,
    source=r"""
{{def_kernel("A", "B")}}
    M = {{size("A", -2)}}
    N = {{size("B", -1)}}
    K = {{size("A", -1)}}

    stride_aq = {{stride("A", 0)}}
    stride_am = {{stride("A", 1)}}
    stride_ak = {{stride("A", 2)}}

    stride_bq = {{stride("B", 0)}}
    stride_bk = {{stride("B", 1)}}
    stride_bn = {{stride("B", 2)}}

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    idx_q = tl.program_id(1)  # batch dimension for BMM
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_q = tl.program_id(1)  # batch dimension for BMM
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    {{store_output(("idx_q", "idx_m", "idx_n"), "acc", "mask")}}
""",
)

aten_bmm = ExternKernelChoice(torch.bmm, "at::bmm_out")
aten_baddbmm = ExternKernelChoice(torch.baddbmm, "at::baddbmm_out")


@register_lowering(aten.bmm)
def tuned_bmm(mat1, mat2, *, layout=None):
    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)

    # options to tune from
    choices = [aten_bmm.bind((mat1, mat2), layout)]
    if use_triton_template(layout):
        for config in mm_configs():
            choices.append(
                bmm_template.generate(
                    (mat1, mat2),
                    layout,
                    **mm_options(config, k, layout),
                )
            )

    return autotune_select_algorithm(choices, [mat1, mat2], layout)


# Don't register this since it is slower than decomposing it
# @register_lowering(aten.baddbmm)
def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
    m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout)

    # options to tune from
    choices = [aten_baddbmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)]
    if use_triton_template(layout):
        for config in mm_configs():
            choices.append(
                bmm_template.generate(
                    (inp, mat1, mat2),
                    layout,
                    **mm_options(config, k, layout),
                    prefix_args=1,
                    epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
                )
            )

    return autotune_select_algorithm(choices, [inp, mat1, mat2], layout)
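
The `bmm_grid` launch geometry is easy to sanity-check in isolation: axis 0 holds the M×N output tiles and axis 1 the batch. A small self-contained check mirroring the `cdiv` arithmetic above (pure Python, not part of the diff):

from math import ceil


def bmm_grid_check(b, m, n, meta):
    # same arithmetic as bmm_grid above, using math.ceil for cdiv
    return (ceil(m / meta["BLOCK_M"]) * ceil(n / meta["BLOCK_N"]), b, 1)


# a 256x512 output tiled 64x64 needs 4 * 8 = 32 programs per batch entry
assert bmm_grid_check(8, 256, 512, {"BLOCK_M": 64, "BLOCK_N": 64}) == (32, 8, 1)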
121
torch/_inductor/kernel/mm.py
Normal file
@ -0,0 +1,121 @@
import logging

import torch
from ..lowering import register_lowering
from ..select_algorithm import (
    autotune_select_algorithm,
    ExternKernelChoice,
    TritonTemplate,
)
from ..utils import use_triton_template
from .mm_common import addmm_epilogue, mm_args, mm_configs, mm_grid, mm_options

log = logging.getLogger(__name__)
aten = torch.ops.aten

mm_template = TritonTemplate(
    name="mm",
    grid=mm_grid,
    source=r"""
{{def_kernel("A", "B")}}
    M = {{size("A", 0)}}
    N = {{size("B", 1)}}
    K = {{size("A", 1)}}
    stride_am = {{stride("A", 0)}}
    stride_ak = {{stride("A", 1)}}
    stride_bk = {{stride("B", 0)}}
    stride_bn = {{stride("B", 1)}}

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)

aten_mm = ExternKernelChoice(torch.mm, "at::mm_out")
aten_addmm = ExternKernelChoice(torch.addmm, "at::addmm_out")


@register_lowering(aten.mm)
def tuned_mm(mat1, mat2, *, layout=None):
    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)

    # options to tune from
    choices = [aten_mm.bind((mat1, mat2), layout)]
    if use_triton_template(layout):
        for config in mm_configs():
            choices.append(
                mm_template.generate(
                    (mat1, mat2),
                    layout,
                    **mm_options(config, k, layout),
                )
            )

    return autotune_select_algorithm(choices, [mat1, mat2], layout)


@register_lowering(aten.addmm)
def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
    m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
    # don't expand inp to make sure fused addmm from cublasLt is used
    if not use_triton_template(layout):
        choices = [aten_addmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)]
        return autotune_select_algorithm(choices, [inp, mat1, mat2], layout)

    # TODO this is not quite fair benchmarking because we won't use fused cublasLt addmm
    # options to tune from
    choices = [
        aten_addmm.bind((inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta)
    ]
    if use_triton_template(layout):
        for config in mm_configs():
            choices.append(
                mm_template.generate(
                    (inp_expanded, mat1, mat2),
                    layout,
                    **mm_options(config, k, layout),
                    prefix_args=1,
                    epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
                )
            )

    return autotune_select_algorithm(choices, [inp_expanded, mat1, mat2], layout)
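
With `max_autotune` enabled, a compiled matmul reaches `tuned_mm`/`tuned_addmm` above, and a pointwise suffix such as `relu` is a candidate epilogue. A hedged usage sketch (requires a CUDA device; sizes are arbitrary, and exact fusion behavior depends on the inductor version):

import torch
import torch._inductor.config as inductor_config

inductor_config.max_autotune = True


@torch.compile
def fused_mm_relu(a, b):
    # relu is a candidate epilogue for the mm template
    return torch.relu(a @ b)


if torch.cuda.is_available():
    a = torch.randn(256, 128, device="cuda")
    b = torch.randn(128, 512, device="cuda")
    out = fused_mm_relu(a, b)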
125
torch/_inductor/kernel/mm_common.py
Normal file
@ -0,0 +1,125 @@
import functools
import logging

import sympy

import torch
from torch._inductor.select_algorithm import realize_inputs
from torch._inductor.virtualized import V
from ..utils import ceildiv as cdiv


log = logging.getLogger(__name__)


@functools.lru_cache(None)
def mm_configs():
    import triton

    return [
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=2, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=3, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=3, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=4, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32}, num_stages=5, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=5, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=2, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128}, num_stages=2, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16}, num_stages=2, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 16}, num_stages=1, num_warps=2
        ),
    ]


def mm_grid(m, n, meta):
    """
    The CUDA grid size for matmul triton templates.
    """
    return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), 1, 1)


def acc_type(dtype):
    if dtype in (torch.float16, torch.bfloat16):
        return "tl.float32"
    return f"tl.{dtype}".replace("torch.", "")


def mm_options(config, sym_k, layout):
    """
    Common options to matmul triton templates.
    """
    even_k_symbolic = (
        # it isn't worth guarding on this
        sympy.gcd(sym_k, config.kwargs["BLOCK_K"])
        == config.kwargs["BLOCK_K"]
    )
    return dict(
        GROUP_M=8,
        EVEN_K=even_k_symbolic,
        ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32,
        ACC_TYPE=acc_type(layout.dtype),
        num_stages=config.num_stages,
        num_warps=config.num_warps,
        **config.kwargs,
    )
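
The `even_k_symbolic` test above avoids a runtime guard by checking divisibility symbolically: K is "even" exactly when BLOCK_K divides it, so the kernel can skip its load masks. For example:

import sympy

K, BLOCK_K = sympy.Integer(4096), sympy.Integer(32)
assert sympy.gcd(K, BLOCK_K) == BLOCK_K      # 32 divides 4096 -> EVEN_K=True

K_odd = sympy.Integer(100)
assert sympy.gcd(K_odd, BLOCK_K) != BLOCK_K  # gcd is 4 -> EVEN_K=False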


def mm_args(mat1, mat2, *others, layout=None):
    """
    Common arg processing for mm,bmm,addmm,etc
    """
    mat1, mat2 = realize_inputs(mat1, mat2)
    *b1, m, k1 = mat1.get_size()
    *b2, k2, n = mat2.get_size()
    b = [V.graph.sizevars.guard_equals(a, b) for a, b in zip(b1, b2)]
    k = V.graph.sizevars.guard_equals(k1, k2)
    if layout is None:
        from torch._inductor.ir import FixedLayout

        layout = FixedLayout(
            mat1.get_device(),
            mat1.get_dtype(),
            [*b, m, n],
        )

    from ..lowering import expand

    others = [realize_inputs(expand(x, layout.size)) for x in others]

    return [m, n, k, layout, mat1, mat2, *others]


def addmm_epilogue(dtype, alpha, beta):
    def epilogue(acc, bias):
        if alpha != 1:
            acc = V.ops.mul(acc, V.ops.constant(alpha, dtype))
        if beta != 1:
            bias = V.ops.mul(bias, V.ops.constant(beta, dtype))
        return V.ops.add(acc, bias)

    return epilogue
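
`addmm_epilogue` builds a closure computing `alpha * acc + beta * bias` inside the fused kernel. The identity it implements can be checked against eager `torch.addmm` (shapes below are arbitrary, chosen for illustration):

import torch

mat1, mat2 = torch.randn(4, 3), torch.randn(3, 5)
inp = torch.randn(4, 5)
alpha, beta = 2.0, 0.5

expected = torch.addmm(inp, mat1, mat2, beta=beta, alpha=alpha)
manual = alpha * (mat1 @ mat2) + beta * inp
torch.testing.assert_close(expected, manual)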
@ -558,6 +558,7 @@ def trunc(x):

@register_lowering(aten.expand, type_promotion_kind=None)
def expand(x, sizes):
    (x,) = promote_constants([x])
    if isinstance(x, ir.BaseConstant):
        return ExpandView.create(x, tuple(sizes))
    assert isinstance(x, TensorBox)
@ -837,21 +838,6 @@ def glu(x, dim=-1):
    return mul(a, sigmoid(b))


@register_lowering(aten.mm)
def mm(a: TensorBox, b: TensorBox):
    return TensorBox.create(ir.MatrixMultiply.create(a, b))


@register_lowering(aten.addmm)
def addmm(inp: TensorBox, a: TensorBox, b: TensorBox, beta=1, alpha=1):
    return TensorBox.create(ir.MatrixMultiplyAdd.create(inp, a, b, beta, alpha))


@register_lowering(aten.bmm)
def bmm(a: TensorBox, b: TensorBox):
    return TensorBox.create(ir.BatchMatrixMultiply.create(a, b))


def register_onednn_fusion_ops():
    if torch._C.has_mkldnn:

@ -3731,3 +3717,20 @@ def foobar(self, *args, **kwargs):
def _realize(x):
    x.realize()
    return clone(x)


def _import_kernels():
    """
    Need to make sure all these get registered in the lowerings dict
    """
    import importlib
    import os

    from . import kernel

    for filename in sorted(os.listdir(os.path.dirname(kernel.__file__))):
        if filename.endswith(".py") and filename[0] != "_":
            importlib.import_module(f"{kernel.__name__}.{filename[:-3]}")


_import_kernels()
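
`_import_kernels` uses a directory scan so that every module in `torch/_inductor/kernel/` runs its `@register_lowering` decorators at import time, without a hand-maintained list. The same pattern, generalized to any package (a sketch, not the real helper):

import importlib
import os
import types


def import_submodules(pkg: types.ModuleType) -> None:
    """Import every non-underscore .py file in a package directory."""
    for filename in sorted(os.listdir(os.path.dirname(pkg.__file__))):
        if filename.endswith(".py") and not filename.startswith("_"):
            importlib.import_module(f"{pkg.__name__}.{filename[:-3]}")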

@ -6,14 +6,14 @@ import logging
import os
import pprint
import textwrap
from typing import Dict, List, Optional, Set, Union
from typing import Dict, List, Optional, Set

import sympy

import torch

from . import config, dependencies, ir, metrics
from .dependencies import MemoryDep, StarDep
from .dependencies import StarDep
from .sizevars import SimplifyIndexing
from .utils import cache_on_self, cmp, dynamo_utils, has_triton
from .virtualized import V
@ -116,6 +116,9 @@ class BaseSchedulerNode:
    def get_mutations(self):
        return self.node.get_mutation_names()

    def has_aliasing_or_mutation(self):
        return bool(self.get_aliases() or self.get_mutations())

    def set_read_writes(self, rw: dependencies.ReadWrites):
        self.read_writes: dependencies.ReadWrites = rw
        self.unmet_dependencies = self.read_writes.reads
@ -162,9 +165,7 @@ class BaseSchedulerNode:
        return False

    def allocate(self):
        from .codegen.triton_template import should_use_template

        if self.node.should_allocate() or should_use_template(self.node):
        if self.node.should_allocate():
            # if self.node should allocate or
            # if self.node is generated by TritonKernelTemplates
            # because Triton kernel could not allocate tensor itself
@ -223,32 +224,6 @@ class ExternKernelSchedulerNode(BaseSchedulerNode):
        return True


class TemplateSchedulerNode(BaseSchedulerNode):
    def __init__(self, scheduler: "Scheduler", node: ir.ExternKernel, group_fn):
        super().__init__(scheduler, node)
        (self._sizes, self._stride) = node.get_group_stride()
        self.group = (node.get_device(), group_fn(self._sizes))
        self.set_read_writes(node.get_read_writes())
        self.update_dep_type()

    def is_template(self):
        return True

    def update_dep_type(self):
        assert len(self.read_writes.writes) == 1
        write = self.read_writes.writes.pop()
        if isinstance(write, StarDep):
            name = write.name
            canonicalized_index, canonicalized_size = self.node.canonicalize()
            new_dep = MemoryDep(name, canonicalized_index, canonicalized_size)
            self.read_writes.writes.add(new_dep)
        else:
            self.read_writes.writes.add(write)

    def get_ranges(self):
        return self._sizes


class NopKernelSchedulerNode(BaseSchedulerNode):
    pass

@ -263,9 +238,15 @@ class SchedulerNode(BaseSchedulerNode):

        self.group = (node.get_device(), group_fn(self._sizes))

        self.set_read_writes(
            dependencies.extract_read_writes(self._body, *self._sizes, normalize=True)
        )
        if self.is_template():
            self.set_read_writes(node.normalized_read_writes())
        else:
            self.set_read_writes(
                dependencies.extract_read_writes(
                    self._body, *self._sizes, normalize=True
                )
            )

        if self.is_reduction():
            # reduction has last (reduced) dim in its sizes, and some
            # downstream dependencies get confused by it
@ -303,7 +284,10 @@ class SchedulerNode(BaseSchedulerNode):
        return self._sizes

    def is_reduction(self):
        return bool(self.node.data.get_reduction_type())
        return bool(self.node.get_reduction_type())

    def is_template(self):
        return isinstance(self.node, ir.TemplateBuffer)

    def allocate(self):
        if (
@ -313,8 +297,7 @@ class SchedulerNode(BaseSchedulerNode):
        ):
            return super().allocate()

        if config.inplace_buffers:
            from .codegen.triton_template import should_use_template
        if config.inplace_buffers and getattr(V.kernel, "mutations", None) is not None:
            from .codegen.wrapper import buffer_reuse_key

            ordered_reads = sorted(self.read_writes.reads, key=lambda x: x.name)
@ -338,7 +321,6 @@ class SchedulerNode(BaseSchedulerNode):
                        input_node.node.get_layout(),
                        (ir.MultiOutputLayout, ir.MutationLayout, ir.AliasedLayout),
                    )
                    and not should_use_template(input_node.node)
                    and buffer_reuse_key(input_node.node)
                    == buffer_reuse_key(self.node)
                ):
@ -398,7 +380,7 @@ class SchedulerNode(BaseSchedulerNode):
        return dependencies.extract_read_writes(fn, sizes)

    def can_inplace(self, read_dep: dependencies.MemoryDep):
        if self.get_aliases():
        if self.get_aliases() or self.is_template():
            return False
        if len(self.read_writes.writes) == 1 and hasattr(read_dep, "index"):
            write_dep = next(iter(self.read_writes.writes))
@ -482,6 +464,10 @@ class FusedSchedulerNode(BaseSchedulerNode):
    def get_device(self):
        return self.group[0]

    @cache_on_self
    def has_aliasing_or_mutation(self):
        return any(x.has_aliasing_or_mutation() for x in self.snodes)

    # None of these need to be implemented, as a FusedSchedulerNode is just an
    # abstraction for scheduling purposes
    def update_mutated_names(self, renames: Dict[str, str]):
@ -561,8 +547,6 @@ class NodeUser:
class Scheduler:
    @dynamo_utils.dynamo_timed
    def __init__(self, nodes):
        from .codegen.triton_template import should_use_template

        super(Scheduler, self).__init__()
        self.backends = {}

@ -577,12 +561,9 @@ class Scheduler:
            ), "All nodes passed to scheduling must have an origin"
            if node.is_no_op():
                self.nodes.append(NopKernelSchedulerNode(self, node))
            elif isinstance(node, ir.ComputedBuffer):
            elif isinstance(node, (ir.ComputedBuffer, ir.TemplateBuffer)):
                group_fn = self.get_backend(node.get_device()).group_fn
                self.nodes.append(SchedulerNode(self, node, group_fn))
            elif isinstance(node, ir.ExternKernel) and should_use_template(node):
                group_fn = self.get_backend(node.get_device()).group_fn
                self.nodes.append(TemplateSchedulerNode(self, node, group_fn))
            elif isinstance(node, ir.ExternKernel):
                self.nodes.append(ExternKernelSchedulerNode(self, node))
            else:
@ -901,6 +882,12 @@ class Scheduler:
            return False  # node2 must go before node1
        if node2.is_template():
            return False  # only epilogues
        if node1.is_template() and (
            node2.has_aliasing_or_mutation()
            or node2.is_reduction()
            or not config.epilogue_fusion
        ):
            return False

        device = node1.get_device()
        if device != node2.get_device():
@ -919,14 +906,8 @@ class Scheduler:
            # node2 depends on node1 outputs
            if not self.can_fuse_vertical(node1, node2):
                return False
            if node1.is_template():
                from .codegen.triton_template import template_can_fuse

                return template_can_fuse(node1, node2)
            return self.get_backend(device).can_fuse_vertical(node1, node2)
        else:  # nodes don't depend on each other, but may have common reads
            if node1.is_template():
                return False
            return self.get_backend(device).can_fuse_horizontal(node1, node2)

    def can_fuse_vertical(self, node1, node2):
@ -981,6 +962,7 @@ class Scheduler:
            abs(node2.min_order - node1.max_order),
        )
        return (
            node1.is_template() == config.epilogue_fusion_first and memory_score > 0,
            node1.is_reduction() == node2.is_reduction() and memory_score > 0,
            memory_score,
            proximity_score,
@ -1074,16 +1056,6 @@ class Scheduler:
            node.codegen(V.graph.wrapper_code)
        self.free_buffers()

    def codegen_template_call(
        self, scheduler_node: Union[FusedSchedulerNode, TemplateSchedulerNode]
    ):
        from .codegen.triton_template import template_codegen

        node, *epilogue = scheduler_node.get_nodes()
        node.allocate()
        template_codegen(self, node, epilogue)
        self.free_buffers()

    def create_backend(self, device: torch.device):
        assert (
            device.type != "cuda" or device.index is not None
@ -1141,7 +1113,8 @@ class Scheduler:
                self.buffer_names_to_free.update(node.last_usage)

                if node.is_template():
                    self.codegen_template_call(node)
                    node, *epilogue = node.get_nodes()
                    self.get_backend(device).codegen_template(node, epilogue)
                elif node.is_extern():
                    self.codegen_extern_call(node)
                elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
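
The scheduler changes above ensure a template only ever appears as `node1` of a fusion and only accepts safe pointwise epilogues. A simplified standalone restatement of that guard (method names mirror the scheduler; everything else is illustrative):

def template_fusion_allowed(node1, node2, epilogue_fusion=True):
    """Simplified restatement of the template checks in Scheduler.can_fuse."""
    if node2.is_template():
        return False  # a template can gain an epilogue, never a prologue
    if node1.is_template() and (
        node2.has_aliasing_or_mutation()
        or node2.is_reduction()
        or not epilogue_fusion
    ):
        return False
    return True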
681
torch/_inductor/select_algorithm.py
Normal file
@ -0,0 +1,681 @@
import builtins
import functools
import inspect
import itertools
import logging
import sys
import textwrap
from io import StringIO

from typing import Any, List
from unittest.mock import patch

import sympy

import torch
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import counters, identity

from . import config, ir
from .codecache import code_hash, DiskCache, PyCodeCache

from .codegen.common import IndentedBuffer
from .codegen.triton import config_of, signature_of, texpr, TritonKernel, TritonPrinter

from .utils import do_bench, sympy_dot, sympy_product
from .virtualized import V

log = logging.getLogger(__name__)

# correctness checks struggle with fp16/tf32
VERIFY = False  # dict(atol=1, rtol=0.05)
PRINT_AUTOTUNE = True


class KernelNamespace:
    pass


# these objects are imported from the generated wrapper code
template_kernels = KernelNamespace()
extern_kernels = KernelNamespace()


class TritonTemplateKernel(TritonKernel):
    def __init__(
        self,
        kernel_name,
        input_nodes,
        output_node,
        defines,
        num_stages,
        num_warps,
        grid_fn,
        meta,
        call_sizes,
        use_jit=True,
        prefix_args=0,
        suffix_args=0,
        epilogue_fn=identity,
    ):
        super().__init__(sympy_product(output_node.get_size()), sympy.Integer(1))
        self.input_nodes = input_nodes
        self.output_node = output_node
        self.named_input_nodes = {}
        self.defines = defines
        self.kernel_name = kernel_name
        self.template_mask = None
        self.use_jit = use_jit
        self.num_stages = num_stages
        self.num_warps = num_warps
        self.grid_fn = grid_fn
        self.meta = meta
        self.call_sizes = call_sizes
        # for templates with fixed epilogues
        self.prefix_args = prefix_args
        self.suffix_args = suffix_args
        self.epilogue_fn = epilogue_fn

    def jit_line(self):
        if self.use_jit:
            return "@triton.jit"

        argdefs, _, signature = self.args.python_argdefs()
        triton_meta = {
            "signature": dict(enumerate(map(signature_of, signature))),
            "device": V.graph.scheduler.current_device.index,
            "constants": {},
        }
        triton_meta["configs"] = [config_of(signature)]
        return (
            f"@template(num_stages={self.num_stages}, num_warps={self.num_warps}, meta={triton_meta!r})\n"
            + "@triton.jit"
        )

    def def_kernel(self, *argnames):
        """
        Hook called from template code to generate function def and
        needed args.
        """
        assert all(isinstance(x, str) for x in argnames)
        renames = IndentedBuffer(initial_indent=1)

        named_args = self.input_nodes[
            self.prefix_args : len(self.input_nodes) - self.suffix_args
        ]

        assert len(argnames) == len(named_args), (
            len(argnames),
            len(named_args),
            self.prefix_args,
            len(self.input_nodes),
        )

        for input_node in self.input_nodes[: self.prefix_args]:
            # get args in correct order
            self.args.input(input_node.get_name())

        for name, input_node in zip(argnames, named_args):
            arg_name = f"arg_{name}"
            self.named_input_nodes[name] = input_node
            self.args.input_buffers[input_node.get_name()] = arg_name
            if input_node.get_layout().offset == 0:
                renames.writeline(f"{name} = {arg_name}")
            else:
                offset = texpr(self.rename_indexing(input_node.get_layout().offset))
                renames.writeline(f"{name} = {arg_name} + {offset}")

        for input_node in self.input_nodes[len(self.input_nodes) - self.suffix_args :]:
            # get args in correct order
            self.args.input(input_node.get_name())

        arg_defs, *_ = self.args.python_argdefs()
        return "\n".join(
            [
                "import triton.language as tl",
                "import triton",
                f"from {config.inductor_import}.triton_ops.autotune import template",
                f"from {config.inductor_import}.utils import instance_descriptor",
                "",
                self.jit_line(),
                f"def {self.kernel_name}({', '.join(arg_defs)}):",
                self.defines,
                renames.getvalue(),
            ]
        )

    def size(self, name: str, index: int):
        """
        Hook called from template code to get the size of an arg.
        Will add needed args to pass it in if it is dynamic.
        """
        assert isinstance(name, str)
        assert isinstance(index, int)
        val = self.named_input_nodes[name].get_size()[index]
        return texpr(self.rename_indexing(val))

    def stride(self, name, index):
        """
        Hook called from template code to get the stride of an arg.
        Will add needed args to pass it in if it is dynamic.
        """
        assert isinstance(name, str)
        assert isinstance(index, int)
        val = self.named_input_nodes[name].get_stride()[index]
        return texpr(self.rename_indexing(val))

    def store_output(self, indices, val, mask):
        """
        Hook called from template code to store the final output
        (if the buffer hasn't been optimized away), then append any
        epilogue fusions.
        """
        assert isinstance(indices, (list, tuple))
        assert isinstance(val, str)
        assert isinstance(mask, str)
        if self.template_mask is None:
            indices = list(map(TritonPrinter.paren, indices))
            index_symbols = [sympy.Symbol(x) for x in indices]
            lengths = [
                V.graph.sizevars.simplify(s) for s in self.output_node.get_size()
            ]
            assert len(indices) == len(lengths)

            # glue to make generated code use same indexing from template
            for name, range_tree_entry in zip(
                indices, self.range_trees[0].construct_entries(lengths)
            ):
                range_tree_entry.set_name(name)
            contiguous_index = sympy_dot(
                ir.FlexibleLayout.contiguous_strides(lengths), index_symbols
            )
            self.body.writeline("xindex = " + texpr(contiguous_index))
            self.range_trees[0].lookup(
                sympy.Integer(1), sympy_product(lengths)
            ).set_name("xindex")
            self.template_mask = mask
            self.template_indices = indices
            output_index = self.output_node.get_layout().make_indexer()(index_symbols)
            if output_index == contiguous_index:
                output_index = sympy.Symbol("xindex")

            epilogue_args = [val]
            for input_node in itertools.chain(
                self.input_nodes[: self.prefix_args],
                self.input_nodes[len(self.input_nodes) - self.suffix_args :],
            ):
                input_node.freeze_layout()
                epilogue_args.append(input_node.make_loader()(index_symbols))

            V.ops.store(
                self.output_node.get_name(),
                output_index,
                self.epilogue_fn(*epilogue_args),
            )
        assert self.template_mask == mask
        self.codegen_body()
        return textwrap.indent(self.body.getvalue(), "    ").strip()

    def make_load(self, name, indices, mask):
        """
        Optional helper called from template code to generate the code
        needed to load from a tensor.
        """
        assert isinstance(indices, (list, tuple))
        assert isinstance(name, str)
        assert isinstance(mask, str)
        stride = self.named_input_nodes[name].get_stride()
        indices = list(map(TritonPrinter.paren, indices))
        assert len(indices) == len(stride)
        index = " + ".join(
            f"{texpr(self.rename_indexing(s))} * {i}" for s, i in zip(stride, indices)
        )
        return f"tl.load({name} + ({index}), {mask})"

    def template_env(self):
        """
        Generate the namespace visible in the template.
        """
        return {
            fn.__name__: fn
            for fn in [
                self.def_kernel,
                self.size,
                self.stride,
                self.store_output,
                self.make_load,
            ]
        }

    def indexing(
        self,
        index: sympy.Expr,
        *,
        copy_shape=None,
        dense_indexing=False,
    ):
        """
        Override the default indexing to use our custom mask and force
        dense indexing.
        """
        result, *mask = super().indexing(
            index,
            dense_indexing=False,
            copy_shape=copy_shape,
            override_mask=self.template_mask,
        )
        result += f" + tl.zeros({self.template_mask}.shape, tl.int32)"
        return (result, *mask)

    def initialize_range_tree(self, pid_cache):
        super().initialize_range_tree(pid_cache)
        # ignore default codegen
        self.body.clear()
        self.indexing_code.clear()

    def call_kernel(self, code, name: str):
        _, call_args, _ = self.args.python_argdefs()

        for i in range(len(call_args)):
            if V.graph.is_unspec_arg(call_args[i]):
                call_args[i] = call_args[i] + ".item()"
        call_args = ", ".join(call_args)

        stream_name = code.write_get_cuda_stream(V.graph.scheduler.current_device.index)

        V.graph.wrapper_code.add_import_once(f"import {self.grid_fn.__module__}")
        meta = V.graph.wrapper_code.add_meta_once(self.meta)

        grid_call = [texpr(V.graph.sizevars.simplify(s)) for s in self.call_sizes] + [
            meta
        ]
        grid_call = (
            f"{self.grid_fn.__module__}.{self.grid_fn.__name__}({', '.join(grid_call)})"
        )
        code.writeline(
            f"{name}.run({call_args}, grid={grid_call}, stream={stream_name})"
        )


@functools.lru_cache(None)
def _jinja2_env():
    try:
        import jinja2

        return jinja2.Environment(
            undefined=jinja2.StrictUndefined,
        )
    except ImportError:
        return None


class TritonTemplate:
    index_counter = itertools.count()
    all_templates = dict()

    @staticmethod
    def _template_from_string(source):
        env = _jinja2_env()
        if env is not None:
            return env.from_string(source)
        return None

    def __init__(self, name: str, grid: Any, source: str, debug=False):
        super().__init__()
        self.name = name
        self.grid = grid
        self.template = self._template_from_string(source)
        assert name not in self.all_templates, "duplicate template name"
        self.all_templates[name] = self
        self.debug = debug

    def generate(
        self,
        input_nodes,
        layout,
        num_stages,
        num_warps,
        prefix_args=0,
        suffix_args=0,
        epilogue_fn=identity,
        **kwargs,
    ):
        assert self.template, "requires jinja2"
        defines = StringIO()
        for name, val in kwargs.items():
            defines.write(f"    {name} : tl.constexpr = {val}\n")
        defines = defines.getvalue()

        fake_out = ir.Buffer("buf_out", layout)
        kernel_name = f"triton_{self.name}"

        kernel_options = dict(
            input_nodes=input_nodes,
            defines=defines,
            num_stages=num_stages,
            num_warps=num_warps,
            grid_fn=self.grid,
            meta=kwargs,
            call_sizes=layout.size,
            prefix_args=prefix_args,
            suffix_args=suffix_args,
            epilogue_fn=epilogue_fn,
        )
        with patch.object(
            V.graph, "get_dtype", self.fake_get_dtype(fake_out)
        ), TritonTemplateKernel(
            kernel_name=kernel_name,
            output_node=fake_out,
            use_jit=True,
            **kernel_options,
        ) as kernel:
            # need to call render twice to get all the needed args right
            self.template.render(
                **kernel.template_env(),
                **kwargs,
            )
            code = self.template.render(
                **kernel.template_env(),
                **kwargs,
            )
            if self.debug:
                print("Generated Code:\n", code)
            mod = PyCodeCache.load(code)
            run = getattr(mod, kernel_name).run
            _, call_args, _ = kernel.args.python_argdefs()

        expected_args = [x.get_name() for x in input_nodes] + [fake_out.get_name()]
        assert list(call_args) == expected_args, (call_args, expected_args)
        extra_args = V.graph.sizevars.size_hints(
            map(sympy.expand, call_args[len(expected_args) :])
        )
        assert not extra_args, "TODO: dynamic shapes"

        def call(*args, out):
            return run(
                *args,
                out,
                *extra_args,
                grid=self.grid(*out.size(), kwargs),
                num_stages=num_stages,
                num_warps=num_warps,
            )

        call.key = mod.key
        call.__file__ = mod.__file__

        kernel_hash_name = f"triton_{self.name}_{next(self.index_counter)}"
        setattr(template_kernels, kernel_hash_name, call)

        def make_kernel_render(out_node):
            kernel = TritonTemplateKernel(
                kernel_name="KERNEL_NAME",
                output_node=out_node,
                use_jit=False,
                **kernel_options,
            )
            render = functools.partial(
                self.template.render,
                **kernel.template_env(),
                **kwargs,
            )
            return kernel, render

        return TritonTemplateCaller(
            kernel_hash_name, input_nodes, layout, make_kernel_render
        )

    @staticmethod
    def fake_get_dtype(fake_out):
        _get_dtype_real = V.graph.get_dtype

        def get_dtype(name):
            if name == fake_out.get_name():
                return fake_out.get_dtype()
            return _get_dtype_real(name)

        return get_dtype


class ExternKernelChoice:
    def __init__(self, kernel, cpp_kernel=None, *, name=None):
        super().__init__()
        name = name or kernel.__name__
        assert callable(kernel)
        assert not hasattr(extern_kernels, name), "duplicate extern kernel"
        self.name = name
        self.cpp_kernel = cpp_kernel
        setattr(extern_kernels, name, kernel)

    def to_callable(self):
        return getattr(extern_kernels, self.name)

    def call_name(self):
        return f"extern_kernels.{self.name}"

    @functools.lru_cache(None)
    def hash_key(self):
        fn = self.to_callable()
        parts = [
            self.name,
            getattr(fn, "__name__", ""),
            getattr(fn, "__module__", ""),
        ]
        try:
            parts.append(inspect.getsource(fn))
        except Exception:
            pass
        return code_hash("-".join(parts))

    def bind(self, input_nodes, layout, **kwargs):
        return ExternKernelCaller(self, input_nodes, layout, kwargs)


class ChoiceCaller:
    def __init__(self, name, input_nodes, layout):
        super().__init__()
        self.name = name
        self.layout = layout
        self.input_nodes = input_nodes


class TritonTemplateCaller(ChoiceCaller):
    def __init__(self, name, input_nodes, layout, make_kernel_render):
        super().__init__(name, input_nodes, layout)
        self.make_kernel_render = make_kernel_render

    def __str__(self):
        return f"TritonTemplateCaller({self.to_callable().__file__})"

    def call_name(self):
        return f"template_kernels.{self.name}"

    def to_callable(self):
        return getattr(template_kernels, self.name)

    def hash_key(self):
        return self.to_callable().key

    def output_node(self):
        return ir.TensorBox.create(
            ir.TemplateBuffer(
                layout=self.layout,
                inputs=self.input_nodes,
                make_kernel_render=self.make_kernel_render,
            )
        )


class ExternKernelCaller(ChoiceCaller):
    def __init__(self, choice: ExternKernelChoice, input_nodes, layout, kwargs=None):
        super().__init__(choice.name, input_nodes, layout)
        self.choice = choice
        self.kwargs = kwargs or {}

    def to_callable(self):
        fn = self.choice.to_callable()
        if self.kwargs:
            return functools.partial(fn, **self.kwargs)
        else:
            return fn

    def hash_key(self):
        return "/".join(
            [
                self.choice.hash_key(),
                repr(self.kwargs),
            ]
        )

    def output_node(self):
        return ir.TensorBox.create(
            ir.ExternKernelOut(
                layout=self.layout,
                inputs=self.input_nodes,
                kernel=self.choice.call_name(),
                cpp_kernel=self.choice.cpp_kernel,
                kwargs=self.kwargs,
            )
        )


class AlgorithmSelectorCache(DiskCache):
    def __call__(self, choices: List[ChoiceCaller], input_nodes, layout):
        if len(choices) == 1:
            return choices[0].output_node()

        def autotune():
            benchmark_fn = self.make_benchmark_fn(choices, input_nodes, layout)
            timings = {}
            for choice in choices:
                try:
                    timings[choice] = benchmark_fn(
                        choice.to_callable(), isinstance(choice, ExternKernelCaller)
                    )
                except RuntimeError as e:
                    if "invalid argument" in str(e):
                        msg = textwrap.dedent(
                            f"""
                            {e}

                            From choice {choices.index(choice)}: {choice}

                            This may mean this GPU is too small for max_autotune mode.
                            """
                        ).strip()
                        if VERIFY:
                            raise RuntimeError(msg)
                        else:
                            log.warning(msg)
                    else:
                        raise
                except AssertionError as e:
                    raise AssertionError(
                        f"Incorrect result from choice {choices.index(choice)} {choice}\n\n{e}"
                    )

            self.log_results(choices[0].name, input_nodes, timings)
            best_choice = builtins.min(timings, key=timings.__getitem__)
            return choices.index(best_choice)

        counters["inductor"]["select_algorithm_autotune"] += 1
        key = [x.hash_key() for x in choices] + [self.key_of(x) for x in input_nodes]
        return choices[self.lookup(key, autotune)].output_node()

    @classmethod
    def make_benchmark_fn(
        cls,
        choices,
        input_nodes,
        layout,
    ):
        example_inputs = [cls.benchmark_example_value(x) for x in input_nodes]
        example_inputs_extern = list(example_inputs)
        for i in range(len(example_inputs)):
            if input_nodes[i].get_layout().offset != 0:
                offset = V.graph.sizevars.size_hint(input_nodes[i].get_layout().offset)
                data = example_inputs_extern[i]
                example_inputs_extern[i] = torch.as_strided(
                    data, data.size(), data.stride(), offset
                )

        out = cls.benchmark_example_value(layout)
        out_extern = torch.as_strided(
            out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset)
        )
        if VERIFY:
            choices[0].to_callable()(*example_inputs_extern, out=out_extern)
            expected = out_extern.clone()

        def benchmark(algo, is_extern):
            out.zero_()
            if is_extern:
                result = do_bench(lambda: algo(*example_inputs_extern, out=out_extern))
            else:
                result = do_bench(lambda: algo(*example_inputs, out=out))
            if VERIFY:
                torch.testing.assert_close(out_extern, expected, **VERIFY)
            torch.cuda.synchronize()  # shake out any CUDA errors
            return result

        return benchmark

    @staticmethod
    def log_results(name, input_nodes, timings):
        if not PRINT_AUTOTUNE:
            return
        sizes = ", ".join(
            [
                "x".join(map(str, V.graph.sizevars.size_hints(n.get_size())))
                for n in input_nodes
            ]
        )
        top_k = sorted(timings, key=timings.__getitem__)[:10]
        best = top_k[0]
        best_time = timings[best][0]
        sys.stderr.write(f"AUTOTUNE {name}({sizes})\n")
        for choice in top_k:
            result = timings[choice]
            sys.stderr.write(
                f"  {choice.name} {result[0]:.4f}s {best_time/result[0]:.1%}\n"
            )

    @staticmethod
    def benchmark_example_value(node):
        """
        Convert an ir.Buffer into a concrete torch.Tensor we can use for
        benchmarking.
        """
        if isinstance(node, ir.Layout):
            node = ir.Buffer("fake", node)
        return rand_strided(
            V.graph.sizevars.size_hints(node.get_size()),
            V.graph.sizevars.size_hints(node.get_stride()),
            device=node.get_device(),
            dtype=node.get_dtype(),
            extra_size=V.graph.sizevars.size_hint(node.get_layout().offset),
        )

    @staticmethod
    def key_of(node):
        """
        Extract the pieces of an ir.Buffer that we should invalidate cached
        autotuning results on.
        """
        sizevars = V.graph.sizevars
        return (
            node.get_device().type,
            str(node.get_dtype()),
            *sizevars.size_hints(node.get_size()),
            *sizevars.size_hints(node.get_stride()),
            sizevars.size_hint(node.get_layout().offset),
        )


autotune_select_algorithm = AlgorithmSelectorCache(__name__)


def realize_inputs(*args):
    if len(args) == 1:
        return ir.ExternKernel.require_stride1(ir.ExternKernel.realize_input(args[0]))
    return [realize_inputs(x) for x in args]
|
||||
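A minimal standalone sketch of the pattern AlgorithmSelectorCache implements: benchmark every candidate once per problem key, cache the winner's index, and reuse it on later hits. Everything here (select_best, _best, the key tuples) is hypothetical illustration, not inductor's API.

import timeit

_best = {}  # problem key -> index of the fastest choice

def select_best(choices, key):
    if key not in _best:
        timings = {
            i: min(timeit.repeat(fn, number=100, repeat=3))
            for i, fn in enumerate(choices)
        }
        _best[key] = min(timings, key=timings.__getitem__)
    return choices[_best[key]]

# usage: pick the faster of two equivalent implementations for this size
fast = select_best(
    [lambda: sum(range(10_000)), lambda: sum(list(range(10_000)))],
    key=("sum", 10_000),
)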
@ -367,6 +367,9 @@ class SizeVarAllocator(object):
        out = sympy_subs(sympy.expand(expr), self.var_to_val)
        return int(out)

    def size_hints(self, exprs: List[Expr]) -> Tuple[int, ...]:
        return tuple(self.size_hint(x) for x in exprs)

    def _lru_cache(self, fn, maxsize=None):
        """
        Wrapper around functools.lru_cache that clears when replacements
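For illustration, size_hint boils down to substituting concrete example values for the symbolic shape variables and taking the integer; a standalone sympy sketch (not the inductor code path):

import sympy

s0, s1 = sympy.symbols("s0 s1")
var_to_val = {s0: 8, s1: 128}  # example values for the shape symbols
expr = sympy.expand(s0 * s1 + 4)
print(int(expr.subs(var_to_val)))  # 1028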
@ -3,6 +3,5 @@ from ..utils import has_triton
if has_triton():
    from .conv import _conv, conv
    from .conv1x1 import _conv1x1, conv1x1
    from .matmul import _matmul_out, matmul_out

    __all__ = ["_conv", "conv", "_conv1x1", "conv1x1", "_matmul_out", "matmul_out"]
    __all__ = ["_conv", "conv", "_conv1x1", "conv1x1"]
@ -12,7 +12,6 @@ import torch

from .. import config
from ..ir import ReductionHint, TileHint
from ..triton_ops.mm_perf_model import estimate_matmul_time
from ..utils import conditional_product, dynamo_utils, has_triton
from .conv_perf_model import (
    early_config_prune as conv_early_config_prune,
@ -106,8 +105,10 @@ class CachingAutotuner(KernelInterface):
        exec(
            f"""
            def launcher({', '.join(def_args)}, grid, stream):
                # set_device(current_device())  # TODO(jansel): is this needed?
                grid_0, grid_1, grid_2 = grid(grid_meta)
                if callable(grid):
                    grid_0, grid_1, grid_2 = grid(grid_meta)
                else:
                    grid_0, grid_1, grid_2 = grid
                bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared,
                              stream, bin.cu_function, None, None, None,
                              {', '.join(call_args)})
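The callable-vs-tuple check added above lets callers pass either a precomputed grid or a function of the kernel metadata; a toy standalone version (resolve_grid is a hypothetical name, not Triton's API):

def resolve_grid(grid, grid_meta):
    # a grid is either a fixed (x, y, z) tuple or a function of kernel metadata
    if callable(grid):
        return grid(grid_meta)
    return grid

assert resolve_grid((8, 1, 1), {}) == (8, 1, 1)
assert resolve_grid(lambda meta: (meta["XBLOCK"], 1, 1), {"XBLOCK": 4}) == (4, 1, 1)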
@ -398,7 +399,7 @@ def pointwise(size_hints, meta, tile_hint=None, filename=None):
    if len(size_hints) == 1:
        return cached_autotune([triton_config(size_hints, 1024)], meta=meta)
    if len(size_hints) == 2:
        if not config.triton.autotune or tile_hint == TileHint.SQUARE:
        if not config.triton.autotune_pointwise or tile_hint == TileHint.SQUARE:
            return cached_autotune([triton_config(size_hints, 32, 32)], meta=meta)
        return cached_autotune(
            [
@ -412,7 +413,7 @@ def pointwise(size_hints, meta, tile_hint=None, filename=None):
        filename=filename,
    )
    if len(size_hints) == 3:
        if not config.triton.autotune:
        if not config.triton.autotune_pointwise:
            return cached_autotune([triton_config(size_hints, 16, 16, 16)], meta=meta)
        return cached_autotune(
            [
@ -448,7 +449,7 @@ def reduction(size_hints, reduction_hint=False, meta=None, filename=None):
        return cached_autotune([outer_config], meta=meta)
    elif reduction_hint == ReductionHint.OUTER_TINY:
        return cached_autotune([tiny_config], meta=meta)
    if not config.triton.autotune:
    if not config.triton.autotune_pointwise:
        return cached_autotune(
            [triton_config_reduction(size_hints, 32, 128)], meta=meta
        )
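The dispatch above keys on how many dimensions were tiled and on the renamed autotune_pointwise flag; a stripped-down hypothetical version of its shape, with strings standing in for triton_config objects:

def pick_configs(size_hints, autotune_pointwise):
    if len(size_hints) == 1:
        return ["1d-1024"]  # single config, no autotuning
    if not autotune_pointwise:
        return ["32x32"] if len(size_hints) == 2 else ["16x16x16"]
    return ["32x32", "64x64", "256x16", "16x256"]  # candidates to benchmark

assert pick_configs([4096], True) == ["1d-1024"]
assert pick_configs([64, 64], False) == ["32x32"]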
@ -469,6 +470,15 @@ def reduction(size_hints, reduction_hint=False, meta=None, filename=None):
    raise NotImplementedError(f"size_hints: {size_hints}")


def template(num_stages, num_warps, meta, filename=None):
    """
    Compile a triton template
    """
    return cached_autotune(
        [triton.Config({}, num_stages=num_stages, num_warps=num_warps)], meta=meta
    )


def conv_heuristics():
    configs = [
        triton.Config(
@ -552,126 +562,6 @@ def conv_heuristics():
    return triton.autotune(configs, key, prune_configs_by=prune_configs_by)

def mm_heuristics():
    from triton import heuristics

    mm_heuristic = heuristics(
        {
            "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
        }
    )
    return mm_heuristic

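The EVEN_K heuristic above just asks whether K divides evenly into its blocks, so the K loop can skip bounds masking; in plain Python:

def even_k(K, BLOCK_K, SPLIT_K):
    return K % (BLOCK_K * SPLIT_K) == 0

assert even_k(4096, 32, 1)      # no masking needed
assert not even_k(1000, 32, 1)  # tail iteration must be masked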
def mm_autotune(get_io_bound_configs=False):
    from triton.ops.matmul import get_configs_io_bound
    from triton.ops.matmul_perf_model import early_config_prune as mm_early_config_prune

    configs = [
        # basic configs for compute-bound matmuls
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1},
            num_stages=5,
            num_warps=2,
        ),
        # good for int8
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
            num_stages=5,
            num_warps=2,
        ),
    ]
    if get_io_bound_configs:
        configs += get_configs_io_bound()
    key = ["M", "N", "K"]
    prune_configs_by = {
        "early_config_prune": mm_early_config_prune,
        "perf_model": estimate_matmul_time,
        "top_k": 10,
    }
    return triton.autotune(configs, key, prune_configs_by=prune_configs_by)


def grid(xnumel, ynumel=None, znumel=None):
    """Helper function to compute triton grids"""

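Such grid helpers reduce to ceiling division per tiled dimension; a hypothetical standalone equivalent (cdiv and make_grid are illustrative names):

def cdiv(a, b):
    return -(a // -b)  # ceiling division

def make_grid(xnumel, xblock):
    return (cdiv(xnumel, xblock), 1, 1)  # one program per x-block

assert make_grid(1000, 256) == (4, 1, 1)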
@ -1,274 +0,0 @@
import torch

from ..utils import has_triton

if has_triton():
    import triton
    import triton.language as tl

    def init_to_zero(name):
        return lambda nargs: nargs[name].zero_()

    @triton.heuristics(
        {
            "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
        }
    )
    @triton.autotune(
        configs=[
            # basic configs for compute-bound matmuls
            triton.Config(
                {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1},
                num_stages=3,
                num_warps=8,
            ),
            triton.Config(
                {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
                num_stages=3,
                num_warps=8,
            ),
            triton.Config(
                {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1},
                num_stages=5,
                num_warps=2,
            ),
            # additional configs
            triton.Config(
                {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "SPLIT_K": 1},
                num_stages=3,
                num_warps=8,
            ),
            triton.Config(
                {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
                num_stages=3,
                num_warps=8,
            ),
            triton.Config(
                {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 64, "SPLIT_K": 1},
                num_stages=2,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
                num_stages=2,
                num_warps=4,
            ),
            # additional configs for K = 64
            triton.Config(
                {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "SPLIT_K": 1},
                num_stages=1,
                num_warps=8,
            ),
            triton.Config(
                {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
                num_stages=1,
                num_warps=8,
            ),
            triton.Config(
                {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1},
                num_stages=1,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 64, "SPLIT_K": 1},
                num_stages=1,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
                num_stages=1,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
                num_stages=5,
                num_warps=2,
            ),
            triton.Config(
                {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1},
                num_stages=1,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
                num_stages=1,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
                num_stages=1,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
                num_stages=1,
                num_warps=2,
            ),
        ],
        # + get_configs_io_bound(),
        key=["M", "N", "K"],
        #
        # key=["M", "N", "K"],
        # prune_configs_by={
        #     "early_config_prune": early_config_prune,
        #     "perf_model": estimate_matmul_time,
        #     "top_k": 18,
        # },
    )
    @triton.jit
    def _kernel(
        A,
        B,
        C,
        M,
        N,
        K,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        BLOCK_M: tl.constexpr,
        BLOCK_N: tl.constexpr,
        BLOCK_K: tl.constexpr,
        GROUP_M: tl.constexpr,
        SPLIT_K: tl.constexpr,
        EVEN_K: tl.constexpr,
        ACC_TYPE: tl.constexpr,
    ):
        # matrix multiplication
        pid = tl.program_id(0)
        pid_z = tl.program_id(1)
        bid = tl.program_id(2)
        grid_m = (M + BLOCK_M - 1) // BLOCK_M
        grid_n = (N + BLOCK_N - 1) // BLOCK_N
        # re-order program ID for better L2 performance
        width = GROUP_M * grid_n
        group_id = pid // width
        group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
        pid_m = group_id * GROUP_M + (pid % group_size)
        pid_n = (pid % width) // (group_size)
        # do matrix multiplication
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
        rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
        # pointers
        A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
        B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
        A += bid * M * K
        B += bid * K * N
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
        for k in range(K, 0, -BLOCK_K * SPLIT_K):
            if EVEN_K:
                a = tl.load(A)
                b = tl.load(B)
            else:
                a = tl.load(A, mask=rk[None, :] < k, other=0.0)
                b = tl.load(B, mask=rk[:, None] < k, other=0.0)
            acc += tl.dot(a, b)
            A += BLOCK_K * SPLIT_K * stride_ak
            B += BLOCK_K * SPLIT_K * stride_bk
        acc = acc.to(C.dtype.element_ty)

        # rematerialize rm and rn to save registers
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
        C += bid * M * N
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        # handles write-back with reduction-splitting
        if SPLIT_K == 1:
            tl.store(C, acc, mask=mask)
        else:
            tl.atomic_add(C, acc, mask=mask)

    def bmm_out(a, b, out):
        # handle non-contiguous inputs if necessary
        if a.stride(0) > 1 and a.stride(1) > 1:
            a = a.contiguous()
        if b.stride(0) > 1 and b.stride(1) > 1:
            b = b.contiguous()
        # checks constraints
        assert a.shape[2] == b.shape[1], "incompatible dimensions"
        B, M, K = a.shape
        _, _, N = b.shape
        # allocates output
        c = out
        # accumulator types
        ACC_TYPE = (
            tl.float32
            if a.dtype in [torch.float16, torch.bfloat16, torch.float32]
            else tl.int32
        )

        # launch kernel
        def grid(META):
            return (
                triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
                META["SPLIT_K"],
                B,
            )

        # grid = lambda META: (
        #     triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
        #     META["SPLIT_K"],
        #     B,
        # )

        # autotuner = _kernel[grid].kernel
        _kernel[grid](a, b, c, M, N, K, K, 1, N, 1, N, 1, GROUP_M=8, ACC_TYPE=ACC_TYPE)
        # print(autotuner.best_config)
        # print(autotuner.configs_timings)
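The SPLIT_K write-back in the deleted kernel above accumulates partial products from each K slice into C. The same math in plain PyTorch terms (splitk_matmul is an illustrative sketch, not the kernel):

import torch

def splitk_matmul(a, b, split_k=4):
    K = a.shape[1]
    c = torch.zeros(a.shape[0], b.shape[1])
    for ks in torch.arange(K).chunk(split_k):
        c += a[:, ks] @ b[ks, :]  # plays the role of tl.atomic_add above
    return c

a, b = torch.randn(8, 64), torch.randn(64, 8)
assert torch.allclose(splitk_matmul(a, b), a @ b, atol=1e-5)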
@ -1,136 +0,0 @@
import torch

from ..utils import has_triton

if has_triton():

    import triton
    import triton.language as tl

    from .autotune import mm_autotune, mm_heuristics

    @mm_heuristics()
    @mm_autotune(get_io_bound_configs=True)
    @triton.jit
    def _kernel(
        A,
        B,
        C,
        M,
        N,
        K,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        allow_tf32: tl.constexpr,
        BLOCK_M: tl.constexpr,
        BLOCK_N: tl.constexpr,
        BLOCK_K: tl.constexpr,
        GROUP_M: tl.constexpr,
        SPLIT_K: tl.constexpr,
        EVEN_K: tl.constexpr,
        ACC_TYPE: tl.constexpr,
    ):
        # matrix multiplication
        pid = tl.program_id(0)
        pid_z = tl.program_id(1)
        grid_m = (M + BLOCK_M - 1) // BLOCK_M
        grid_n = (N + BLOCK_N - 1) // BLOCK_N
        # re-order program ID for better L2 performance
        width = GROUP_M * grid_n
        group_id = pid // width
        group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
        pid_m = group_id * GROUP_M + (pid % group_size)
        pid_n = (pid % width) // (group_size)
        # do matrix multiplication
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
        rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
        # pointers
        A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
        B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
        for k in range(K, 0, -BLOCK_K * SPLIT_K):
            if EVEN_K:
                a = tl.load(A)
                b = tl.load(B)
            else:
                a = tl.load(A, mask=rk[None, :] < k, other=0.0)
                b = tl.load(B, mask=rk[:, None] < k, other=0.0)
            acc += tl.dot(a, b, allow_tf32=allow_tf32)
            A += BLOCK_K * SPLIT_K * stride_ak
            B += BLOCK_K * SPLIT_K * stride_bk
        acc = acc.to(C.dtype.element_ty)
        # rematerialize rm and rn to save registers
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        # handles write-back with reduction-splitting
        if SPLIT_K == 1:
            tl.store(C, acc, mask=mask)
        else:
            tl.atomic_add(C, acc, mask=mask)

    class _matmul_out:
        kernel = _kernel

        @staticmethod
        def _call(a, b, out, allow_tf32=True):
            # handle non-contiguous inputs if necessary
            if a.stride(0) > 1 and a.stride(1) > 1:
                a = a.contiguous()
            if b.stride(0) > 1 and b.stride(1) > 1:
                b = b.contiguous()
            # checks constraints
            assert a.shape[1] == b.shape[0], "incompatible dimensions"
            M, K = a.shape
            _, N = b.shape
            # allocates output
            c = out
            # accumulator types
            ACC_TYPE = (
                tl.float32
                if a.dtype in [torch.float16, torch.bfloat16, torch.float32]
                else tl.int32
            )

            # launch kernel (grid defined using def instead of lambda to pass `make lint`)
            def grid(META):
                return (
                    triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
                    META["SPLIT_K"],
                )

            # grid = lambda META: (
            #     triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
            #     META["SPLIT_K"],
            # )
            _kernel[grid](
                a,
                b,
                c,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                allow_tf32=allow_tf32,
                GROUP_M=8,
                ACC_TYPE=ACC_TYPE,
            )

        @staticmethod
        def forward(a, b, out, allow_tf32=True):
            return _matmul_out._call(a, b, out, allow_tf32)

    matmul_out = _matmul_out.forward
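The deleted wrapper threaded allow_tf32 down to tl.dot; the eager-mode analogue is PyTorch's global matmul switch, which only affects float32 inputs on TF32-capable GPUs:

import torch

torch.backends.cuda.matmul.allow_tf32 = False  # keep fp32 matmuls at full precision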
@ -1,90 +0,0 @@
import torch


def estimate_matmul_time(
    # backend, device,
    num_warps,
    num_stages,
    A,
    B,
    M,
    N,
    K,
    BLOCK_M,
    BLOCK_N,
    BLOCK_K,
    SPLIT_K,
    debug=False,
    **kwargs,
):
    """return estimated running time in ms
    = max(compute, loading) + store"""
    import triton
    import triton._C.libtriton.triton as _triton
    from triton.ops.matmul_perf_model import (
        get_dram_gbps as get_dram_gbps,
        get_tflops as get_tflops,
    )

    backend = _triton.runtime.backend.CUDA
    device = torch.cuda.current_device()
    dtype = A.dtype
    dtsize = A.element_size()

    num_cta_m = triton.cdiv(M, BLOCK_M)
    num_cta_n = triton.cdiv(N, BLOCK_N)
    num_cta_k = SPLIT_K
    num_ctas = num_cta_m * num_cta_n * num_cta_k

    # If the input is smaller than the block size
    M, N = max(M, BLOCK_M), max(N, BLOCK_N)

    # time to compute
    total_ops = 2 * M * N * K / (1024 * 1024 * 1024)  # GOPS
    tput = get_tflops(backend, device, num_ctas, num_warps, dtype)
    compute_ms = total_ops / tput

    # time to load data
    num_sm = _triton.runtime.num_sm(backend, device)
    active_cta_ratio = min(1, num_ctas / num_sm)
    active_cta_ratio_bw1 = min(
        1, num_ctas / 32
    )  # 32 active ctas are enough to saturate
    active_cta_ratio_bw2 = max(
        min(1, (num_ctas - 32) / (108 - 32)), 0
    )  # 32-108, remaining 5%
    dram_bw = get_dram_gbps(backend, device) * (
        active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05
    )  # in GB/s
    l2_bw = dram_bw * 4  # rough estimation (should be 4.7 for A100?)
    # assume 80% of (following) loads are in L2 cache
    load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1))
    load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1)
    load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1))
    load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1)
    # total
    total_dram = (load_a_dram + load_b_dram) / (1024 * 1024)  # MB
    total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024)
    # loading time in ms
    load_ms = total_dram / dram_bw + total_l2 / l2_bw

    # estimate storing time
    store_bw = dram_bw * 0.6  # :o
    store_c_dram = M * N * dtsize * SPLIT_K / (1024 * 1024)  # MB
    if SPLIT_K == 1:
        store_ms = store_c_dram / store_bw
    else:
        reduce_bw = store_bw
        store_ms = store_c_dram / reduce_bw
        # c.zero_()
        zero_ms = M * N * 2 / (1024 * 1024) / store_bw
        store_ms += zero_ms

    total_time_ms = max(compute_ms, load_ms) + store_ms
    if debug:
        print(
            f"Total time: {total_time_ms}ms, compute time: {compute_ms}ms, "
            f"loading time: {load_ms}ms, store time: {store_ms}ms, "
            f"Activate CTAs: {active_cta_ratio*100}%"
        )
    return total_time_ms
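The deleted estimator is a roofline model: runtime is roughly the max of compute time and memory time. A toy version with made-up peak numbers (the 100 TFLOPS and 1.5 TB/s figures are assumptions for illustration, not measured values):

def roofline_ms(flops, bytes_moved, peak_tflops=100.0, peak_gbps=1500.0):
    compute_ms = flops / (peak_tflops * 1e12) * 1e3
    memory_ms = bytes_moved / (peak_gbps * 1e9) * 1e3
    return max(compute_ms, memory_ms)

# fp16 4096^3 matmul: 2*M*N*K flops, 2 bytes per element, each matrix moved once
M = N = K = 4096
print(roofline_ms(2 * M * N * K, 2 * (M * K + K * N + M * N)))  # compute-bound here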
@ -1,6 +1,7 @@
import collections
import contextlib
import functools
import logging
import math
import operator
import os
@ -17,9 +18,11 @@ import sympy
import torch
from torch.fx.immutable_collections import immutable_dict, immutable_list

from . import config
from . import config, config as inductor_config
from .cuda_properties import get_device_capability

log = logging.getLogger(__name__)

VarRanges = Dict[sympy.Expr, sympy.Expr]

# We import torchdynamo modules indirectly to allow a future rename to torch.dynamo
@ -30,6 +33,13 @@ dynamo_optimizations = import_module(f"{config.dynamo_import}.optimizations")
dynamo_testing = import_module(f"{config.dynamo_import}.testing")
dynamo_utils = import_module(f"{config.dynamo_import}.utils")

try:
    from triton.testing import do_bench
except ImportError:

    def do_bench(*args, **kwargs):
        raise NotImplementedError("requires Triton")


@functools.lru_cache(None)
def has_triton():
@ -432,3 +442,21 @@ class DeferredLineBase:

    def __len__(self):
        return len(self.line)


@functools.lru_cache(None)
def is_big_gpu(index):
    cores = torch.cuda.get_device_properties(index).multi_processor_count
    if cores < 80:  # V100
        log.warning("not enough cuda cores to use max_autotune mode")
        return False
    return True


def use_triton_template(layout):
    return (
        inductor_config.max_autotune
        and layout.device.type == "cuda"
        and layout.dtype in (torch.float16, torch.bfloat16, torch.float32)
        and is_big_gpu(layout.device.index or 0)
    )
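is_big_gpu keys off the streaming-multiprocessor count CUDA reports; to inspect the value for a local device (requires a CUDA build of PyTorch):

import torch

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(props.name, props.multi_processor_count)  # >= 80 SMs passes the check above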