# Owner(s): ["module: inductor"]

import contextlib
import os
import unittest
from unittest import skipUnless

import numpy as np
import sympy

import torch
import torch.nn.functional as F
from torch import nn
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import same
from torch._inductor import config as inductor_config, ir, metrics
from torch._inductor.codegen.triton import TritonScheduling
from torch._inductor.graph import GraphLowering
from torch._inductor.scheduler import SchedulerNode
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.test_operators import realize
from torch._inductor.utils import is_big_gpu, run_and_get_code, sympy_index_symbol
from torch._inductor.virtualized import ops, V
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
from torch.utils._ordered_set import OrderedSet
from torch.utils._pytree import tree_map
from torch.utils._sympy.functions import FloorDiv, ModularIndexing


# set so that metrics appear
torch._logging.set_logs(inductor_metrics=True)
DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1"


if HAS_GPU:
    torch.set_default_device(GPU_TYPE)


class MockScheduler:
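    """
    Minimal stand-in for the real Scheduler: it provides just the attributes and
    methods that SchedulerNode / TritonScheduling look up in these tests.
    """
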
    available_buffer_names = ()

    @staticmethod
    def get_backend(cls, *args):
        return TritonScheduling(cls)

    def can_buffer_be_removed_through_fusion(self, *args, **kwargs):
        return False


class MockSchedulerTest(TestCase):
    _exit_stack = None

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        gm = torch.fx.symbolic_trace(lambda: 0)
        graph = GraphLowering(gm)
        graph.scheduler = MockScheduler
        cls._exit_stack = contextlib.ExitStack()
        cls._exit_stack.enter_context(V.set_graph_handler(graph))

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()
        cls._exit_stack.close()


@inductor_config.patch(loop_ordering_after_fusion=True)
class ImplDetailTest(MockSchedulerTest):
    @staticmethod
    def _get_snode_body_sym_prefix(snode):
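        # Return the one-letter prefix of the iteration variables in the node
        # body (e.g. "p" for p0, p1, ...), read off the first var_ranges key.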
        body = snode._body
        prefix = ""

        for var in body.var_ranges:
            prefix = str(var)[0]
            break

        assert prefix
        return prefix

    @staticmethod
    def _create_computed_buffer_ax2(sizes=(32, 64), strides=None):
        """
        Create a ComputedBuffer for 'a x 2'
        """
        if strides is None:
            strides = ir.FlexibleLayout.contiguous_strides(sizes)

        box_a = ir.TensorBox.create(
            ir.Buffer(
                name="a",
                layout=ir.FixedLayout(
                    torch.device(GPU_TYPE),
                    dtype=torch.float32,
                    size=sizes,
                    stride=strides,
                ),
            )
        )
        box_a_loader = box_a.make_loader()

        def inner_fn(index):
            return box_a_loader(index) * 2

        buf = ir.Pointwise.create(
            device=box_a.get_device(),
            dtype=box_a.get_dtype(),
            inner_fn=inner_fn,
            ranges=box_a.get_size(),
        )
        buf.realize()
        computed_buf = buf.data.data
        computed_buf.decide_layout()
        return computed_buf

    def test_reorder_twice(self):
        """
        This may happen in practice if we pick an order when fusing A and B,
        and then pick another order for AB when we fuse C into it.

        E.g. this happens for BertForMaskedLM.
        """

        buf = self._create_computed_buffer_ax2()
        snode = SchedulerNode(V.graph.scheduler, buf)
        snode.apply_new_loop_order([1, 0])
        prefix1 = self._get_snode_body_sym_prefix(snode)
        self.assertTrue(prefix1 == "p")
        snode.apply_new_loop_order([1, 0])
        prefix2 = self._get_snode_body_sym_prefix(snode)
        self.assertTrue(prefix2 == "p")

    def test_reorder_and_merge_loops(self):
        sizes = (1024, 2048)
        strides = (1, 1024)
        buf = self._create_computed_buffer_ax2(sizes, strides)
        old_sizes, old_body = buf.simplify_and_reorder()

        # Make sure loop reordering happens here
        self.assertTrue(tuple(old_sizes[0]) == tuple(reversed(sizes)), f"{old_sizes=}")
        new_body = old_body.merge_loops()
        new_sizes = new_body.sizes
        self.assertTrue(tuple(new_sizes[0]) == (np.prod(sizes),), f"{new_sizes=}")

    def test_merge_loops_invalidate_pw_dep_cache(self):
        sizes = (1024, 2048)
        strides = (2048, 1)
        buf = self._create_computed_buffer_ax2(sizes, strides)

        snode = SchedulerNode(V.graph.scheduler, buf)
        old_var_ranges = snode.pointwise_read_writes().var_ranges
        self.assertTrue(len(old_var_ranges) == 2)  # the 2 dimensions are not merged
        snode.merge_loops()
        new_var_ranges = snode.pointwise_read_writes().var_ranges

        # we cache the pointwise_read_writes result on a scheduler node;
        # make sure new_var_ranges is refreshed by invalidating the cache.
        self.assertTrue(len(new_var_ranges) == 1)  # the 2 dimensions get merged

    def test_reorder_modular_indexing(self):
        """
        There was a bug where we wrongly mapped i0 to the dimension with size 49
        when reordering the loop, causing ModularIndexing to get optimized away
        as a no-op.
        """

        def _create_computed_buffer():
            def inner_fn(index):
                i0, _, i2, i3 = index
                return ops.load(
                    "primal", i3 + 49 * i2 + 2401 * ModularIndexing(i0, 1, 64)
                )

            buf = ir.Pointwise.create(
                device=torch.device(GPU_TYPE),
                dtype=torch.float32,
                inner_fn=inner_fn,
                ranges=[128, 4, 49, 49],
            )
            buf.realize()
            cbuf = buf.data.data
            cbuf.decide_layout()
            return cbuf

        buf = _create_computed_buffer()
        _, body = buf.simplify_and_reorder()
        new_body = body.reorder_iter_loops([1, 2, 3, 0])

        z0, z1, z2, z3 = (sympy_index_symbol(f"p{i}") for i in range(4))
        self.assertEqual(body.var_ranges, {z0: 128, z1: 4, z2: 49, z3: 49})
        self.assertEqual(
            body.indexing_exprs["index0"],
            z3 + 49 * z2 + 2401 * ModularIndexing(z0, 1, 64),
        )
        self.assertEqual(new_body.var_ranges, {z0: 4, z1: 49, z2: 49, z3: 128})
        self.assertEqual(
            new_body.indexing_exprs["index0"],
            z2 + 49 * z1 + 2401 * ModularIndexing(z3, 1, 64),
        )


@inductor_config.patch(
    {
        "benchmark_kernel": True,
        "loop_ordering_after_fusion": True,
        "triton.unique_kernel_names": True,
    }
)
class LoopOrderingTest(TestCase):
    device = GPU_TYPE

    def do_acc_test(self, f, *args, cast_fp8=True):
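        # Compare eager vs. torch.compile results. fp8 outputs are optionally
        # upcast to fp32 first, since allclose-style checks are not implemented
        # for fp8 dtypes.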
        expect = f(*args)
        actual = torch.compile(f)(*args)

        if cast_fp8:

            def _cast(x):
                if isinstance(x, torch.Tensor) and x.dtype in (
                    torch.float8_e5m2,
                    torch.float8_e4m3fn,
                ):
                    return x.to(torch.float32)
                return x

            # Workaround for the issue that calling allclose on fp8 tensors triggers
            # RuntimeError: "mul_cuda" not implemented for 'Float8_e4m3fn'
            expect = tree_map(_cast, expect)
            actual = tree_map(_cast, actual)
        self.assertTrue(same(expect, actual, tol=1e-3))

    def setUp(self):
        super().setUp()
        metrics.reset()

    def test_for_reordering_reindex(self):
        """
        ComputedBuffer.iter_reoredering_reindex can cause some fusion
        opportunities to be skipped.

        In this test case, Inductor used to generate 2 triton kernels.
        By removing ComputedBuffer.iter_reoredering_reindex, we can fuse those
        two kernels into a single one.
        """

        def f(x, y):
            """
            Add a matmul since inductor may force a layout for the output.
            """
            return (x.sum(dim=-1) + 1) @ y

        A, B = 20, 30
        # Make the first 2 dimensions unable to merge on purpose so that
        # ComputedBuffer.iter_reoredering_reindex will be updated.
        x = rand_strided([A, A, B], [B, B * A + 300, 1], device=GPU_TYPE)
        y = torch.randn(A, A)

        self.do_acc_test(f, x, y)
        self.assertEqual(1, metrics.generated_kernel_count)
        expected_num_bytes = 0
        expected_num_bytes += A * A * B + A * A  # for the fused reduction
        expected_num_bytes += A * A * 3  # for matmul
        expected_num_bytes *= x.itemsize
        self.assertEqual(expected_num_bytes, metrics.num_bytes_accessed)

    def test_apbt_realize(self):
        M = 1024
        N = 2048

        def f(x, y):
            """
            2 kernels will be generated without loop ordering after fusion:
            https://gist.github.com/shunting314/44df83f71de2c110232c50ac6638ed69
            """
            x = realize(x * 2)
            y = realize(y * 3)
            return x + y

        x = torch.randn(M, N)
        y = torch.randn(N, M).t()

        self.do_acc_test(f, x, y)
        self.assertEqual(1, metrics.generated_kernel_count)

    def test_sum_and_t(self):
        N = 1024

        def f(x):
            return x.sum(dim=-1), x.t().contiguous()

        x = torch.randn(N, N * 2)
        self.do_acc_test(f, x)
        self.assertEqual(1, metrics.generated_kernel_count)

    def test_pw_outer_red(self):
        def f(x):
            x = realize(x + 1)
            return x.sum(dim=[0, 1])

        # make the first 2 dimensions small so we don't split the reduction
        x = torch.randn(2, 4, 512)
        self.do_acc_test(f, x)
        self.assertEqual(1, metrics.generated_kernel_count)

    def test_pw_outer_red_2(self):
        """
        The pointwise kernel is a fused kernel
        """

        def f(x):
            x = realize(x + 1)
            x = realize(x - 2)
            x = realize(x * 3)
            return x.sum(dim=[0, 1])

        # make the first 2 dimensions small so we don't split the reduction
        x = torch.randn(2, 4, 512)
        self.do_acc_test(f, x)
        self.assertEqual(1, metrics.generated_kernel_count)

    @inductor_config.patch(split_reductions=False)
    def test_different_reduction_order(self):
        """
        We should not reorder loops in this case, since reordering loops does
        not help!
        """

        def f(x):
            return x.sum(dim=0), x.sum(dim=1)

        x = torch.randn(1024, 2048)
        self.do_acc_test(f, x)
        self.assertEqual(2, metrics.generated_kernel_count)
        self.assertEqual(0, metrics.num_loop_reordering)

    def test_keep_fake_dep(self):
        """
        In this model, there are fake dependencies (StarDep) between Scatter
        and a following mutation kernel that computes the gradients of
        the embedding tables.

        When we do loop reordering for the mutation kernel, we re-analyze
        the node's dependencies. But the analysis result does not contain
        those fake dependencies. We have to add them back manually.
        """
        V = 2048
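        # vocab size; note that this local V shadows the virtualized V imported
        # above for the rest of this test.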
        hidden_size = 64
        max_seqlen = 512
        batch_size = 8

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.word_embeddings = nn.Embedding(V, hidden_size)
                self.position_embeddings = nn.Embedding(max_seqlen, hidden_size)
                self.layer_norm = nn.LayerNorm(hidden_size)

            def forward(self, input_ids, labels, position_ids):
                emb = self.word_embeddings(input_ids) + self.position_embeddings(
                    position_ids
                )
                return self.layer_norm(emb)

        m = Model()

        @torch.compile
        def f(*args):
            m(*args).sum().backward()

        input_ids = torch.randint(0, V, (batch_size, max_seqlen))
        labels = torch.randint(0, V, (batch_size, max_seqlen))
        position_ids = torch.arange(max_seqlen)[None, :]
        # Make sure this line does not raise exceptions. If we miss
        # fake dependencies after loop reordering, we may get an exception that
        # some buffer is used before being defined.
        f(input_ids, labels, position_ids)

    def test_different_broadcast_shapes(self):
        def f(x, y, c):
            return x + c, y + c

        x = torch.randn(4, 256, 1024)
        y = torch.randn(2, 512, 1024)
        c = torch.randn(1024)
        self.do_acc_test(f, x, y, c)

        # The two kernels are not fused because c is broadcast
        self.assertEqual(2, metrics.generated_kernel_count)

    def test_view(self):
        """
        Passing this test relies on comparing normalized MemoryDep.
        Normalization here means merging contiguous loops.

        To make loop reordering work, we don't merge loops when creating
        SchedulerNode. Thus we need to explicitly normalize MemoryDep when
        we check if two MemoryDeps match.
        """

        def f(x):
            y = x.sin()
            x = realize(x.view(10, 10))
            return x, y

        x = torch.randn(100)
        self.do_acc_test(f, x)
        self.assertEqual(1, metrics.generated_kernel_count)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 requires H100+ and MI300+")
    def test_fp8_cast_and_t(self):
        """
        This test repros the unable-to-fuse issue in
        https://github.com/pytorch/pytorch/issues/130015
        for fp8 cast and transpose
        """

        def f(x, scale):
            x = x * scale
            x = x.clamp(-1 * E4M3_MAX_POS, E4M3_MAX_POS)
            x = x.to(torch.float8_e4m3fn)
            x_t = x.t().contiguous().t()
            return x, x_t

        x = torch.randn(4096, 4096, dtype=torch.bfloat16)
        scale = torch.Tensor([10.0]).to(GPU_TYPE)
        E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
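        # Largest finite value representable in float8_e4m3fn; f above clamps
        # to this range (via closure) before casting.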

        self.do_acc_test(f, x, scale)
        self.assertEqual(1, metrics.generated_kernel_count)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 requires H100+ and MI300+")
    def test_fp8_pattern_2(self):
        """
        This test repros the fp8 fusion issue reported here:
        https://github.com/pytorch/pytorch/issues/133242
        """
        ref_dtype = torch.bfloat16
        M, K = 4096, 4096

        input_tensor = torch.randn(
            M, K, device=GPU_TYPE, dtype=ref_dtype, requires_grad=False
        )
        scale = torch.Tensor([10.0]).to(GPU_TYPE)

        E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max

        def test_pattern2(tensor_x_inp, scale_x):
            tensor_x = tensor_x_inp * scale_x
            tensor_x = tensor_x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
            tensor_fp8 = tensor_x.to(torch.float8_e4m3fn)

            tensor_x_t = (tensor_x_inp * scale_x).t()
            tensor_x_t = tensor_x_t.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
            tensor_fp8_t = tensor_x_t.to(torch.float8_e4m3fn)

            tensor_fp8_t = tensor_fp8_t.contiguous().t()

            return (tensor_fp8, tensor_fp8_t)

        test_pattern = torch.compile(test_pattern2)
        tensor_fp8, tensor_fp8_t = test_pattern(input_tensor, scale)

        self.assertEqual(1, metrics.generated_kernel_count)

        expected_numbytes = scale.nbytes  # scalar
        expected_numbytes += input_tensor.nbytes  # input
        expected_numbytes += tensor_fp8.nbytes + tensor_fp8_t.nbytes  # output
        self.assertEqual(expected_numbytes, metrics.num_bytes_accessed)

    def test_outer_dimension_softmax(self):
        """
        This test repros the unable-to-fuse problem for outer-dimension
        softmax reported here: https://github.com/pytorch/pytorch/issues/93718

        Perf data on h100:
        - without loop ordering after fusion 0.564 ms
        - with loop ordering after fusion 0.302 ms
        This is a 1.87x speedup.
        """
        x = torch.randn(32, 2**21, device=GPU_TYPE)

        def f(x):
            return F.softmax(x, dim=0)

        self.do_acc_test(f, x)
        self.assertEqual(1, metrics.generated_kernel_count)

    def test_outer_dimension_sum_fuse_with_pw(self):
        """
        Test the fusion of an outer-dimension sum with a following pointwise.
        Perf data on h100:
        - without loop ordering after fusion 0.436 ms
        - with loop ordering after fusion 0.260 ms
        This is a 1.68x speedup.
        """
        x = torch.randn(32, 2**21, device=GPU_TYPE)

        def f(x):
            return x.sum(dim=0, keepdim=True) + x

        self.do_acc_test(f, x)
        self.assertEqual(1, metrics.generated_kernel_count)

        if DO_PERF_TEST:
            from triton.testing import do_bench

            optf = torch.compile(f)
            print(f"ms={do_bench(lambda: optf(x))}")

    # Disable split reduction to make it easier to calculate the expected
    # number of bytes accessed. In this case, split reduction does not
    # help perf much.
    @inductor_config.patch(split_reductions=False)
    def test_fuse_reduction_with_tiled_pw(self):
        def f(x):
            y = torch.sum(torch.sum(x, dim=-1))

            z = x / 10.0
            z_t = z.t().contiguous().t()
            return y, z, z_t

        # use these input sizes to test for perf
        if DO_PERF_TEST:
            M, N = 1024 * 32, 1024 * 8
        else:
            M, N = 200, 100
        x = torch.randn(M, N, device=GPU_TYPE)
        actual = f(x)
        opt_f = torch.compile(f)
        expected = opt_f(x)
        self.assertTrue(same(actual, expected, tol=1e-3))

        # We should fuse the first sum with the two pointwise kernels.
        # Overall we read x once for all these three kernels and write
        # out 2 buffers with the same size as x.
        # This should be sort of 'optimal' for this workload.
        expected_numbytes = x.nbytes * 3

        # A small amount of extra memory access for:
        # - store output for the first reduction
        # - load input for the second reduction
        # - store output for the second reduction
        expected_numbytes += (M * 2 + 1) * x.itemsize

        print(expected_numbytes)
        self.assertEqual(expected_numbytes, metrics.num_bytes_accessed)

        if DO_PERF_TEST:
            from triton.testing import do_bench

            ms = do_bench(lambda: opt_f(x))
            print(f"{ms=:.3f}")

    @inductor_config.patch(
        {
            "max_autotune": True,
            "max_autotune_gemm_backends": "TRITON",
            "test_configs.max_mm_configs": 4,
        }
    )
    @skipUnless(HAS_GPU and is_big_gpu(), "Need big gpu for max-autotune")
    def test_interaction_with_triton_template(self):
        """
        Make sure the dependency prefix for TritonTemplate and its
        prologue match.
        """

        @torch.compile
        def f(x, y):
            return (x.expand([1, y.shape[0]]) + 1) @ y

        x = torch.randn([1, 1], device=GPU_TYPE)
        y = torch.randn([64, 128], device=GPU_TYPE)

        out, code = run_and_get_code(f, x, y)

        # When the benchmark_kernel flag is on, we have one more .run
        # call in the benchmarking code.
        FileCheck().check("def call(").check_count(
            ".run(", 1 + int(inductor_config.benchmark_kernel), exactly=True
        ).run(code[0])

    @inductor_config.patch(
        {
            "max_autotune": True,
            "max_autotune_gemm_backends": "TRITON",
            "test_configs.max_mm_configs": 4,
        }
    )
    @skipUnless(HAS_GPU and is_big_gpu(), "Need big gpu for max-autotune")
    def test_interaction_with_multi_template(self):
        """
        Skip MultiTemplateBuffer during loop reordering
        """

        @torch.compile
        def f(x, y):
            return (x @ y), x + 1

        N = 2
        x = torch.randn([N, N], device=GPU_TYPE, dtype=torch.bfloat16)
        y = torch.randn([N, N], device=GPU_TYPE, dtype=torch.bfloat16)

        out, code = run_and_get_code(f, x, y)
        # didn't fuse due to small savings
        FileCheck().check_count("@triton.jit", 2, exactly=True).run(code[0])

    def test_fuse_with_scalar_shared_memory(self):
        """
        Make sure that if we could fuse two nodes sharing a scalar before,
        we can still do it with LOAF applied.

        This is not really a big deal, but some tests rely on this and
        having fewer kernels has some small benefits.
        """

        @torch.compile
        def f(x):
            return torch.mean(x)

        x = torch.randn([5, 5], device=GPU_TYPE)
        out, code = run_and_get_code(f, x)
        FileCheck().check_count("@triton.jit", 1, exactly=True).run(code[0])

    def test_3dred_pw_2d_outer_red(self):
        """
        Test the following pattern. We have a 3d contiguous tensor [m, n, k] as input.
        1. do a reduction on the k dimension and get a [m, n] tensor
        2. do a pointwise operation on this [m, n] tensor (and realize the computation)
        3. do an outer reduction on the output of step 2 along the m dimension.

        Each of these steps generates a kernel before fusion.
        Without any loop reordering, kernel 1 and kernel 2 will get fused, and kernel 3
        will be separate.

        But if we reorder the loop for kernel 2, then kernel 2 will get fused with
        kernel 3, and the fused kernel-2-3 cannot be fused with kernel 1.

        The older version of the LOAF algorithm would reorder in this case, but there is
        no real benefit. There are even some slight downsides:
        1. the original fusion without loop reordering is more natural
        2. fusing kernel 1 with kernel 2 may help precision when the output of kernel 1
           is in low precision. By fusing kernel 1 and kernel 2, the pointwise operation
           operates in fp32 precision thanks to fusion.
        """
        M, N, K = 64, 64, 64

        def f(x):
            x = x.sum(dim=-1)
            x = x + 1  # can be more complex like sigmoid or other ops
            return x, x.sum(dim=0)

        x = torch.randn(M, N, K, device=GPU_TYPE)
        self.do_acc_test(f, x)
        self.assertEqual(0, metrics.num_loop_reordering)


@inductor_config.patch(
    {
        "triton.unique_kernel_names": True,
        "loop_ordering_after_fusion": True,
        "triton.max_tiles": 3,
        "triton.coalesce_tiling_analysis": True,
    }
)
@instantiate_parametrized_tests
class MemoryCoalescingTest(MockSchedulerTest):
    """Tests for memory coalescing analysis with specific tensor sizes."""

    device = GPU_TYPE
    _exit_stack = None

    def setUp(self):
        super().setUp()
        metrics.reset()

    def _create_buffer(self, name, sizes):
        """Create a buffer with specified sizes"""

        strides = ir.FlexibleLayout.contiguous_strides(sizes)

        box = ir.TensorBox.create(
            ir.Buffer(
                name=name,
                layout=ir.FixedLayout(
                    torch.device(self.device),
                    dtype=torch.float32,
                    size=sizes,
                    stride=strides,
                ),
            )
        )
        box_loader = box.make_loader()

        def inner_fn(index):
            return box_loader(index) * 2

        buf = ir.Pointwise.create(
            device=box.get_device(),
            dtype=box.get_dtype(),
            inner_fn=inner_fn,
            ranges=box.get_size(),
        )
        buf.realize()
        computed_buf = buf.data.data
        computed_buf.decide_layout()

        return computed_buf

    def _create_scheduler_node(self, buf):
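        # min_order / max_order are normally assigned by the real Scheduler;
        # set them by hand here so that FusedSchedulerNode.fuse can combine
        # these hand-built nodes (presumably any wide-enough range works).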
        s = SchedulerNode(V.graph.scheduler, buf)
        s.min_order = 0
        s.max_order = 100
        return s

    @parametrize(
        "inps",
        (
            ((128, 384, 196), (768, 64, 196), (128, 6, 64, 196)),
            ((64,), (16, 4), (16, 4)),
            ((5, 6), (3, 10), (30,)),
            ((5, 6, 20), (3, 10, 20), (30, 20)),
        ),
    )
    def test_inferred_splits(self, inps):
        """
        Test memory coalescing analysis with the specified tensor sizes,
        using direct SchedulerNode creation (e.g. sizes (128, 384, 196)
        and (768, 64, 196)).
        """

        s1, s2, expected_size = inps

        # Create buffers with the specified sizes
        buf1 = self._create_buffer("buffer1", s1)
        buf2 = self._create_buffer("buffer2", s2)

        # Create scheduler nodes
        snode1 = self._create_scheduler_node(buf1)
        snode2 = self._create_scheduler_node(buf2)

        # Create a fused node
        fused_node = torch._inductor.scheduler.FusedSchedulerNode.fuse(snode1, snode2)

        from torch._inductor import tiling_utils

        fused_norm_read_writes = tiling_utils.extract_normalized_read_writes(fused_node)

        var_ranges = fused_norm_read_writes.var_ranges
        self.assertEqual(list(var_ranges.values()), list(expected_size))

    def test_remapped_reads(self):
        from torch._inductor import tiling_utils

        def fn(nodes):
            assert len(nodes) == 1
            fused_norm_read_writes = tiling_utils.extract_normalized_read_writes(
                nodes[0]
            )

            self.assertTrue(len(fused_norm_read_writes.var_ranges) == 2)

            # both reads remapped correctly
            FileCheck().check("4*n0 + n1").run(
                repr(fused_norm_read_writes.reads.keys())
            )
            FileCheck().check("n0 + 4*n1").run(
                repr(fused_norm_read_writes.reads.keys())
            )

            return nodes

        with torch._inductor.config.patch(_post_fusion_custom_pass=fn):

            @torch.compile()
            def foo(x, y):
                return x + y

            foo(
                torch.rand([4, 4], device=GPU_TYPE),
                torch.rand([4, 4], device=GPU_TYPE).T,
            )

    def test_remapped_reads_split(self):
        from torch._inductor import tiling_utils

        def fn(nodes):
            self.assertTrue(len(nodes) == 1)
            fused_norm_read_writes = tiling_utils.extract_normalized_read_writes(
                nodes[0]
            )

            inp_node_reads = nodes[0].get_nodes()[1]._body.get_read_exprs()
            node_ranges = nodes[0].get_nodes()[1]._body.var_ranges
            self.assertTrue(len(node_ranges) == 1)
            self.assertTrue(next(iter(node_ranges.values())) == 36)
            var = next(iter(node_ranges.keys()))

            r = FloorDiv(var, 6) + 6 * ModularIndexing(var, 1, 6)
            self.assertTrue(r in inp_node_reads)

            # mapped reads
            self.assertTrue(list(fused_norm_read_writes.var_ranges.values()) == [6, 6])
            n0, n1 = list(fused_norm_read_writes.var_ranges.keys())

            # translation of above is n0 + 6 * n1
            self.assertTrue((n0 + 6 * n1) in fused_norm_read_writes.reads.keys())

            return nodes

        with torch._inductor.config.patch(_post_fusion_custom_pass=fn):

            @torch.compile()
            def foo(x, y):
                return (
                    x + y
                ).contiguous().flatten() + torch.ops._inductor_test.realize(
                    (y.T + 1).flatten()
                )

            foo(
                torch.rand([6, 6], device=GPU_TYPE),
                torch.rand([6, 6], device=GPU_TYPE).T,
            )

    def test_reduction_pointwise(self):
        # test one pw var, one red var
        from torch._inductor import tiling_utils

        def fn(nodes):
            self.assertTrue(len(nodes) == 1)
            fused_rw = tiling_utils.extract_normalized_read_writes(nodes[0])

            i_vars, r_vars = fused_rw.index_vars, fused_rw.reduce_vars
            self.assertTrue(len(i_vars) == 1)
            self.assertTrue(len(r_vars) == 1)

            # single write to index var
            self.assertTrue(
                fused_rw.index_vars[0] == next(iter(fused_rw.writes.keys()))
            )

            # the write to the fused intermediary node should be removed
            self.assertTrue(len(fused_rw.writes) == 1)

            # single read
            self.assertTrue(len(fused_rw.reads) == 1)
            # that is applied to two bufs
            self.assertTrue(len(next(iter(fused_rw.reads.values()))) == 2)

            # and the read should be in terms of the index + reduce var,
            # even though node is pointwise
            self.assertTrue(256 * i_vars[0] + r_vars[0] in fused_rw.reads)

            return nodes

        with torch._inductor.config.patch(_post_fusion_custom_pass=fn), torch.no_grad():

            @torch.compile()
            def foo(x, y):
                out = torch.ops._inductor_test.realize(x + y)
                return out.sum(dim=1)

            foo(
                torch.rand(256, 256, device=GPU_TYPE),
                torch.rand(256, 256, device=GPU_TYPE),
            )

    def test_reduction_no_pointwise(self):
        # test one pw var, one red var
        from torch._inductor import tiling_utils

        def fn(nodes):
            self.assertTrue(len(nodes) == 1)
            fused_rw = tiling_utils.extract_normalized_read_writes(nodes[0])

            i_vars, r_vars = fused_rw.index_vars, fused_rw.reduce_vars
            self.assertTrue(len(i_vars) == 0)
            self.assertTrue(len(r_vars) == 1)

            return nodes

        with torch._inductor.config.patch(_post_fusion_custom_pass=fn), torch.no_grad():

            @torch.compile()
            def foo(x):
                return x.sum()

            foo(torch.rand(1024, device=GPU_TYPE))

    def test_coalescing(self):
        from torch._inductor import tiling_utils

        # Define symbolic variables
        i, j, n, m = sympy.symbols("i j n m", integer=True)

        # Test cases: (expression, var_ranges, expected_result)
        test_cases = [
            # Simple direct case
            (i + j * 5, {i: 10, j: 8}, i),
            # Floor division case
            (i + FloorDiv(j, 2), {i: 4, j: 8}, i),
            # Modular indexing
            (i * 10 + ModularIndexing(j, 1, 3), {i: 5, j: 10}, j),
            # Case with no coalescing variable
            (i * 2 + j * 3, {i: 8, j: 5}, None),
            # Division case
            (i / 2, {i: 10}, None),
            # More complex floor division
            (j + FloorDiv(i, 3), {i: 6, j: 12}, j),
            # Addition inside modular indexing
            (ModularIndexing(i + 3, 1, 6), {i: 8, j: 12}, i),
        ]

        for expr, var_ranges, expected in test_cases:
            # Test the function
            result = tiling_utils.find_coalesced_var(expr, var_ranges)
            self.assertEqual(result, expected)

    @parametrize("downcast_transposed_v", (False, True))
    def test_tiled_coalesce_analysis(self, downcast_transposed_v):
        # test one pw var, one red var
        from torch._inductor import tiling_utils

        def fn(nodes):
            self.assertTrue(len(nodes) == 1)

            coalesce_analysis = tiling_utils.analyze_memory_coalescing(nodes[0])

            i_vars = coalesce_analysis.norm_read_writes.index_vars

            # because the output is contiguous, the second dimension should
            # coalesce twice as many bytes as the first dimension
            # if y is not downcasted;
            # if downcasted, they should be equal, because of the larger dtype size.
            # we also weight writes x 2
            cont_reads = coalesce_analysis.coalesced_by_var[i_vars[1]]
            t_reads = coalesce_analysis.coalesced_by_var[i_vars[0]]

            if not downcast_transposed_v:
                self.assertEqual(cont_reads, t_reads * 3)
            else:
                self.assertEqual(cont_reads, t_reads * 1.5)

            return nodes

        with torch._inductor.config.patch(_post_fusion_custom_pass=fn), torch.no_grad():

            @torch.compile()
            def foo(x, y):
                return x + y.to(x.dtype)

            y_dtype = torch.float if not downcast_transposed_v else torch.float64
            foo(
                torch.rand(256, 256, device=GPU_TYPE),
                torch.rand(256, 256, device=GPU_TYPE, dtype=y_dtype).T,
            )

    def test_solve_for_zero(self):
        from torch._inductor import tiling_utils

        x, y = sympy.symbols("x y", integer=True)
        # Test cases: (expression, expected_result)
        test_cases = [
            # Simple linear expressions
            (x + 5, (-5)),
            (2 * x - 10, (5)),
            # Constant expressions (should return None)
            (sympy.Integer(7), None),
            (sympy.Integer(0), None),
            # FloorDiv cases (should return None per function)
            (FloorDiv(x, 2), None),
            (FloorDiv(x, 2) + 5, None),
            # ModularIndexing cases
            (ModularIndexing(x, 1, 5), (5)),
            (ModularIndexing(x, 1, 3), (3)),
            # Expressions with no constant solution
            (x**2 + 1, None),  # No real solution
        ]
        for expr, expected in test_cases:
            result = tiling_utils.solve_for_zero(expr)
            self.assertEqual(result, expected)

    def test_solve_for_tiling(self):
        from torch._inductor import tiling_utils

        x = sympy.Symbol("x", integer=True)

        test_cases = [
            # Simple linear cases that coalesce
            (3 * x, None),
            # # # # Expression with no free symbols
            # (sympy.Integer(5), None),
            (x / 3, 3),
            (FloorDiv(x * 2, 6), 3),
            # # ModularIndexing expressions
            (ModularIndexing(FloorDiv(x, 4), 1, 64), 4),
            (x + ModularIndexing(x, 1, 5), None),
            (x**2, None),  # Non-linear, diff is not constant
            (4096 * (ModularIndexing(32 * x, 1, 2048)) + FloorDiv(x, 64), 64),
            (4096 * (ModularIndexing(x, 1, 2048)) + FloorDiv(x, 2048), 2048),
        ]

        for expr, expected in test_cases:
            result = tiling_utils.solve_for_tiling(expr)
            self.assertEqual(result, expected)

    def test_induced_fused_tiling(self):
        from torch._inductor import tiling_utils

        def fn(nodes):
            self.assertTrue(len(nodes) == 1)

            coalesce_analysis = tiling_utils.analyze_memory_coalescing(nodes[0])
            self.assertEqual(coalesce_analysis.suggested_split.tiling_factor, 64)
            return nodes

        with torch._inductor.config.patch(_post_fusion_custom_pass=fn), torch.no_grad():

            def forward(permute):
                clone = torch.ops.aten.clone.default(
                    permute, memory_format=torch.contiguous_format
                )
                view_2 = torch.ops.aten.view.default(clone, [-1, 32])
                amax_1 = torch.ops.aten.amax.default(view_2, [1])
                return amax_1

            XDIM = 2048
            YDIM = 4096

            arg0_1 = torch.randn([XDIM, YDIM], device=GPU_TYPE, dtype=torch.bfloat16)
            permute = torch.ops.aten.permute.default(arg0_1, [1, 0])

            out, code = run_and_get_code(torch.compile(forward), (permute))

            self.assertEqual(out, forward(permute))
            FileCheck().check("YBLOCK").check("XBLOCK").run(code[0])


layouts = ("cont", "NHWC", "T")
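# Layout names consumed by TestTiling.T below: contiguous, channels-last, and a
# transpose-then-restore pattern that keeps the shape but changes the strides.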


@inductor_config.patch(
    {
        "triton.unique_kernel_names": True,
        "loop_ordering_after_fusion": True,
        "triton.coalesce_tiling_analysis": True,
    }
)
@instantiate_parametrized_tests
class TestTiling(TestCase):
    def T(self, layout: str):
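        # Build a test tensor whose memory layout is selected by `layout`:
        # plain contiguous, transposed-then-restored strides ("T"), or
        # channels_last ("NHWC").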
        SIZE_A = 128
        SIZE_B = 256
        SIZE_C = 512

        if layout == "cont":
            return torch.rand(SIZE_A, SIZE_B, SIZE_C, device=GPU_TYPE).unsqueeze(0)
        elif layout == "T":
            return (
                torch.rand(SIZE_A, SIZE_B, SIZE_C, device=GPU_TYPE)
                .transpose(1, 2)
                .contiguous()
                .transpose(1, 2)
                .unsqueeze(0)
            )
        else:
            assert layout == "NHWC"
            return torch.rand([1, SIZE_A, SIZE_B, SIZE_C], device=GPU_TYPE).to(
                memory_format=torch.channels_last
            )

    @parametrize("a", layouts)
    @parametrize("b", layouts)
    def test_pointwise(self, a, b):
        def foo(x, y):
            return x + y

        x, y = self.T(a), self.T(b)
        res, code = run_and_get_code(torch.compile(foo), x, y)

        if a != b:
            FileCheck().check("ynumel").run(code[0])
        else:
            FileCheck().check_not("ynumel").run(code[0])

        self.assertEqual(res, foo(x, y))

    def test_tiled_reduction(self):
        def f(a, b):
            return (a * b).sum(dim=-1)

        N = 512
        inps = (
            torch.randn(N, N, N, device=GPU_TYPE).permute(2, 1, 0),
            torch.randn(N, N, N, device=GPU_TYPE).permute(1, 2, 0),
        )
        f_c = torch.compile(f)
        out, code = run_and_get_code(f_c, *inps)

        FileCheck().check_dag("xnumel = 512").check_dag("ynumel = 512").check_dag(
            "rnumel"
        ).run(code[0])
        self.assertEqual(out, f(*inps), atol=0.001, rtol=0.04)

    def test_3d_pointwise(self):
        inps = (self.T("cont"), self.T("T"), self.T("NHWC"))

        def f(x, y, z):
            return x + y + z

        f_c = torch.compile(f)
        out, code = run_and_get_code(f_c, *inps)

        FileCheck().check_dag("znumel").check_dag("ynumel").check_dag("xnumel").run(
            code[0]
        )
        self.assertEqual(out, f(*inps))

    def test_cat(self):
        # test unwrapping Identity

        def f(x, y):
            return torch.cat((x, y)) + 1

        x = self.T("cont")
        y = self.T("T")

        inps = (x, y)

        f_c = torch.compile(f)
        out, code = run_and_get_code(f_c, *inps)
        FileCheck().check_dag("ynumel").check_dag("xnumel").run(code[0])
        self.assertEqual(out, f(*inps))

    def test_penalized_small_dim(self):
        x = torch.rand([2000, 1], device=GPU_TYPE)
        y = torch.rand([4, 1], device=GPU_TYPE).T

        # don't tile when it doesn't affect total coalesced mem accesses much
        def f(x, y):
            return x + y

        inps = (x, y)

        f_c = torch.compile(f)
        out, code = run_and_get_code(f_c, *inps)
        FileCheck().check_not("ynumel").check_dag("xnumel").run(code[0])
        self.assertEqual(out, f(*inps))

    def test_mutation_deps(self):
        def f(x):
            return x.add_(1)

        x = self.T("cont")

        from torch._inductor import tiling_utils

        def fn(nodes):
            self.assertTrue(len(nodes) == 1)

            coalesce_analysis = tiling_utils.analyze_memory_coalescing(nodes[0])
            assert coalesce_analysis is not None

            reads = coalesce_analysis.norm_read_writes.reads
            writes = coalesce_analysis.norm_read_writes.writes

            self.assertTrue(len(reads) == 1 and len(writes) == 1)
            self.assertEqual(
                list(coalesce_analysis.norm_read_writes.reads.values()),
                [OrderedSet(("arg0_1",))],
            )
            self.assertEqual(
                list(coalesce_analysis.norm_read_writes.writes.values()),
                [OrderedSet(("buf1",))],
            )

            return nodes

        with torch._inductor.config.patch(_post_fusion_custom_pass=fn), torch.no_grad():
            torch.compile(f)(x)


if __name__ == "__main__":
    if HAS_GPU:
        run_tests()