# Owner(s): ["module: inductor"] import contextlib import os import unittest from unittest import skipUnless import numpy as np import sympy import torch import torch.nn.functional as F from torch import nn from torch._dynamo.testing import rand_strided from torch._dynamo.utils import same from torch._inductor import config as inductor_config, ir, metrics from torch._inductor.codegen.triton import TritonScheduling from torch._inductor.graph import GraphLowering from torch._inductor.scheduler import SchedulerNode from torch._inductor.test_case import run_tests, TestCase from torch._inductor.test_operators import realize from torch._inductor.utils import is_big_gpu, run_and_get_code, sympy_index_symbol from torch._inductor.virtualized import ops, V from torch.testing import FileCheck from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8 from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, ) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU from torch.utils._ordered_set import OrderedSet from torch.utils._pytree import tree_map from torch.utils._sympy.functions import FloorDiv, ModularIndexing # set so that metrics appear torch._logging.set_logs(inductor_metrics=True) DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1" if HAS_GPU: torch.set_default_device(GPU_TYPE) class MockScheduler: available_buffer_names = () @staticmethod def get_backend(cls, *args): return TritonScheduling(cls) def can_buffer_be_removed_through_fusion(self, *args, **kwargs): return False class MockSchedulerTest(TestCase): _exit_stack = None @classmethod def setUpClass(cls): super().setUpClass() gm = torch.fx.symbolic_trace(lambda: 0) graph = GraphLowering(gm) graph.scheduler = MockScheduler cls._exit_stack = contextlib.ExitStack() cls._exit_stack.enter_context(V.set_graph_handler(graph)) @classmethod def tearDownClass(cls): super().tearDownClass() cls._exit_stack.close() @inductor_config.patch(loop_ordering_after_fusion=True) class ImplDetailTest(MockSchedulerTest): @staticmethod def _get_snode_body_sym_prefix(snode): body = snode._body prefix = "" for var in body.var_ranges: prefix = str(var)[0] break assert prefix return prefix @staticmethod def _create_computed_buffer_ax2(sizes=(32, 64), strides=None): """ Create a ComputedBuffer for 'a x 2' """ if strides is None: strides = ir.FlexibleLayout.contiguous_strides(sizes) box_a = ir.TensorBox.create( ir.Buffer( name="a", layout=ir.FixedLayout( torch.device(GPU_TYPE), dtype=torch.float32, size=sizes, stride=strides, ), ) ) box_a_loader = box_a.make_loader() def inner_fn(index): return box_a_loader(index) * 2 buf = ir.Pointwise.create( device=box_a.get_device(), dtype=box_a.get_dtype(), inner_fn=inner_fn, ranges=box_a.get_size(), ) buf.realize() computed_buf = buf.data.data computed_buf.decide_layout() return computed_buf def test_reorder_twice(self): """ This may happen in practice if we pick a order when fusing A and B. Then we pick another order for AB when we fusion C into it. E.g. happens for BertForMaskedLM. 
""" buf = self._create_computed_buffer_ax2() snode = SchedulerNode(V.graph.scheduler, buf) snode.apply_new_loop_order([1, 0]) prefix1 = self._get_snode_body_sym_prefix(snode) self.assertTrue(prefix1 == "p") snode.apply_new_loop_order([1, 0]) prefix2 = self._get_snode_body_sym_prefix(snode) self.assertTrue(prefix2 == "p") def test_reorder_and_merge_loops(self): sizes = (1024, 2048) strides = (1, 1024) buf = self._create_computed_buffer_ax2(sizes, strides) old_sizes, old_body = buf.simplify_and_reorder() # Make sure loop reordering happens here self.assertTrue(tuple(old_sizes[0]) == tuple(reversed(sizes)), f"{old_sizes=}") new_body = old_body.merge_loops() new_sizes = new_body.sizes self.assertTrue(tuple(new_sizes[0]) == (np.prod(sizes),), f"{new_sizes=}") def test_merge_loops_invalidate_pw_dep_cache(self): sizes = (1024, 2048) strides = (2048, 1) buf = self._create_computed_buffer_ax2(sizes, strides) snode = SchedulerNode(V.graph.scheduler, buf) old_var_ranges = snode.pointwise_read_writes().var_ranges self.assertTrue(len(old_var_ranges) == 2) # 2 dimension not merged snode.merge_loops() new_var_ranges = snode.pointwise_read_writes().var_ranges # we cache pointwise_read_writes result on a scheduler node # make sure new_var_ranges is refreshed by invalidating the cache. self.assertTrue(len(new_var_ranges) == 1) # 2 dimensions get merged def test_reorder_modular_indexing(self): """ There was a bug that we wrongly map i0 to the dimension with size 49 when reordering the loop and cause ModularIndexing get optimized away as an no-op. """ def _create_computed_buffer(): def inner_fn(index): i0, _, i2, i3 = index return ops.load( "primal", i3 + 49 * i2 + 2401 * ModularIndexing(i0, 1, 64) ) buf = ir.Pointwise.create( device=torch.device(GPU_TYPE), dtype=torch.float32, inner_fn=inner_fn, ranges=[128, 4, 49, 49], ) buf.realize() cbuf = buf.data.data cbuf.decide_layout() return cbuf buf = _create_computed_buffer() _, body = buf.simplify_and_reorder() new_body = body.reorder_iter_loops([1, 2, 3, 0]) z0, z1, z2, z3 = (sympy_index_symbol(f"p{i}") for i in range(4)) self.assertEqual(body.var_ranges, {z0: 128, z1: 4, z2: 49, z3: 49}) self.assertEqual( body.indexing_exprs["index0"], z3 + 49 * z2 + 2401 * ModularIndexing(z0, 1, 64), ) self.assertEqual(new_body.var_ranges, {z0: 4, z1: 49, z2: 49, z3: 128}) self.assertEqual( new_body.indexing_exprs["index0"], z2 + 49 * z1 + 2401 * ModularIndexing(z3, 1, 64), ) @inductor_config.patch( { "benchmark_kernel": True, "loop_ordering_after_fusion": True, "triton.unique_kernel_names": True, } ) class LoopOrderingTest(TestCase): device = GPU_TYPE def do_acc_test(self, f, *args, cast_fp8=True): expect = f(*args) actual = torch.compile(f)(*args) if cast_fp8: def _cast(x): if isinstance(x, torch.Tensor) and x.dtype in ( torch.float8_e5m2, torch.float8_e4m3fn, ): return x.to(torch.float32) return x # Wordaround the issue that call allclose on fp8 tensor triggers error # RuntimeError: "mul_cuda" not implemented for 'Float8_e4m3fn' expect = tree_map(_cast, expect) actual = tree_map(_cast, actual) self.assertTrue(same(expect, actual, tol=1e-3)) def setUp(self): super().setUp() metrics.reset() def test_for_reordering_reindex(self): """ ComputedBuffer.iter_reoredering_reindex can cause some fusion opportunitiies being skipped. In this test case, Inductor generates 2 triton kernels before. By removing ComputedBuffer.iter_reoredering_reindex, we can fuse those two kernels into a single one. """ def f(x, y): """ Add a matmul since inductor may force layout for output. 
""" return (x.sum(dim=-1) + 1) @ y A, B = 20, 30 # Make the first 2 dimension not able to merge on purpose so that # ComputedBuffer.iter_reoredering_reindex will be updated. x = rand_strided([A, A, B], [B, B * A + 300, 1], device=GPU_TYPE) y = torch.randn(A, A) self.do_acc_test(f, x, y) self.assertEqual(1, metrics.generated_kernel_count) expected_num_bytes = 0 expected_num_bytes += A * A * B + A * A # for the fused reduction expected_num_bytes += A * A * 3 # for matmul expected_num_bytes *= x.itemsize self.assertEqual(expected_num_bytes, metrics.num_bytes_accessed) def test_apbt_realize(self): M = 1024 N = 2048 def f(x, y): """ There will be 2 kernels being generated without loop ordering after fusion: https://gist.github.com/shunting314/44df83f71de2c110232c50ac6638ed69 """ x = realize(x * 2) y = realize(y * 3) return x + y x = torch.randn(M, N) y = torch.randn(N, M).t() self.do_acc_test(f, x, y) self.assertEqual(1, metrics.generated_kernel_count) def test_sum_and_t(self): N = 1024 def f(x): return x.sum(dim=-1), x.t().contiguous() x = torch.randn(N, N * 2) self.do_acc_test(f, x) self.assertEqual(1, metrics.generated_kernel_count) def test_pw_outer_red(self): def f(x): x = realize(x + 1) return x.sum(dim=[0, 1]) # make the first 2 dimension small so we don't split the reduction x = torch.randn(2, 4, 512) self.do_acc_test(f, x) self.assertEqual(1, metrics.generated_kernel_count) def test_pw_outer_red_2(self): """ The pointwise kernel is a fused kernel """ def f(x): x = realize(x + 1) x = realize(x - 2) x = realize(x * 3) return x.sum(dim=[0, 1]) # make the first 2 dimension small so we don't split the reduction x = torch.randn(2, 4, 512) self.do_acc_test(f, x) self.assertEqual(1, metrics.generated_kernel_count) @inductor_config.patch(split_reductions=False) def test_different_reduction_order(self): """ We should not reorder loops in this case. Since reordering loops does not help! """ def f(x): return x.sum(dim=0), x.sum(dim=1) x = torch.randn(1024, 2048) self.do_acc_test(f, x) self.assertEqual(2, metrics.generated_kernel_count) self.assertEqual(0, metrics.num_loop_reordering) def test_keep_fake_dep(self): """ In this model, there are fake dependencies (StarDep) between Scatter and a following mutation kernel that computes the gradients of the embedding tables. When we do loop reordering for the mutation kernel, we re-analyze the node's dependencies. But the analysis result does not contains those fake dependencies. Have to add them back manually. """ V = 2048 hidden_size = 64 max_seqlen = 512 batch_size = 8 class Model(nn.Module): def __init__(self): super().__init__() self.word_embeddings = nn.Embedding(V, hidden_size) self.position_embeddings = nn.Embedding(max_seqlen, hidden_size) self.layer_norm = nn.LayerNorm(hidden_size) def forward(self, input_ids, labels, position_ids): emb = self.word_embeddings(input_ids) + self.position_embeddings( position_ids ) return self.layer_norm(emb) m = Model() @torch.compile def f(*args): m(*args).sum().backward() input_ids = torch.randint(0, V, (batch_size, max_seqlen)) labels = torch.randint(0, V, (batch_size, max_seqlen)) position_ids = torch.arange(max_seqlen)[None, :] # Make sure this line does not raise exceptions. If we miss # fake dependencies after loop reordering, we may get exception that # some buffer is used before being defined. 
        f(input_ids, labels, position_ids)

    def test_different_broadcast_shapes(self):
        def f(x, y, c):
            return x + c, y + c

        x = torch.randn(4, 256, 1024)
        y = torch.randn(2, 512, 1024)
        c = torch.randn(1024)
        self.do_acc_test(f, x, y, c)

        # The two kernels are not fused because c is broadcast
        self.assertEqual(2, metrics.generated_kernel_count)

    def test_view(self):
        """
        Passing this test relies on comparing normalized MemoryDeps.
        Normalization here means merging contiguous loops.

        To make loop reordering work, we don't merge loops when creating a
        SchedulerNode. Thus we need to explicitly normalize MemoryDeps when we
        check whether two MemoryDeps match.
        """

        def f(x):
            y = x.sin()
            x = realize(x.view(10, 10))
            return x, y

        x = torch.randn(100)
        self.do_acc_test(f, x)
        self.assertEqual(1, metrics.generated_kernel_count)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 requires H100+ and MI300+")
    def test_fp8_cast_and_t(self):
        """
        This test repros the not-able-to-fuse issue in
        https://github.com/pytorch/pytorch/issues/130015
        for fp8 cast and transpose.
        """

        def f(x, scale):
            x = x * scale
            x = x.clamp(-1 * E4M3_MAX_POS, E4M3_MAX_POS)
            x = x.to(torch.float8_e4m3fn)
            x_t = x.t().contiguous().t()
            return x, x_t

        x = torch.randn(4096, 4096, dtype=torch.bfloat16)
        scale = torch.Tensor([10.0]).to(GPU_TYPE)
        E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max

        self.do_acc_test(f, x, scale)
        self.assertEqual(1, metrics.generated_kernel_count)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 requires H100+ and MI300+")
    def test_fp8_pattern_2(self):
        """
        This test repros the fp8 fusion issue reported here:
        https://github.com/pytorch/pytorch/issues/133242
        """
        ref_dtype = torch.bfloat16
        M, K = 4096, 4096

        input_tensor = torch.randn(
            M, K, device=GPU_TYPE, dtype=ref_dtype, requires_grad=False
        )
        scale = torch.Tensor([10.0]).to(GPU_TYPE)

        E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max

        def test_pattern2(tensor_x_inp, scale_x):
            tensor_x = tensor_x_inp * scale_x
            tensor_x = tensor_x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
            tensor_fp8 = tensor_x.to(torch.float8_e4m3fn)

            tensor_x_t = (tensor_x_inp * scale_x).t()
            tensor_x_t = tensor_x_t.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
            tensor_fp8_t = tensor_x_t.to(torch.float8_e4m3fn)

            tensor_fp8_t = tensor_fp8_t.contiguous().t()

            return (tensor_fp8, tensor_fp8_t)

        test_pattern = torch.compile(test_pattern2)
        tensor_fp8, tensor_fp8_t = test_pattern(input_tensor, scale)

        self.assertEqual(1, metrics.generated_kernel_count)

        expected_numbytes = scale.nbytes  # scalar
        expected_numbytes += input_tensor.nbytes  # input
        expected_numbytes += tensor_fp8.nbytes + tensor_fp8_t.nbytes  # output
        self.assertEqual(expected_numbytes, metrics.num_bytes_accessed)

    def test_outer_dimension_softmax(self):
        """
        This test repros the not-able-to-fuse problem for outer dimension
        softmax reported here:
        https://github.com/pytorch/pytorch/issues/93718

        Perf data on h100:
        - without loop ordering after fusion: 0.564 ms
        - with loop ordering after fusion: 0.302 ms
        This is a 1.87x speedup.
        """
        x = torch.randn(32, 2**21, device=GPU_TYPE)

        def f(x):
            return F.softmax(x, dim=0)

        self.do_acc_test(f, x)
        self.assertEqual(1, metrics.generated_kernel_count)

    def test_outer_dimension_sum_fuse_with_pw(self):
        """
        Test the fusion of an outer dimension sum with a following pointwise.

        Perf data on h100:
        - without loop ordering after fusion: 0.436 ms
        - with loop ordering after fusion: 0.260 ms
        This is a 1.68x speedup.
""" x = torch.randn(32, 2**21, device=GPU_TYPE) def f(x): return x.sum(dim=0, keepdim=True) + x self.do_acc_test(f, x) self.assertEqual(1, metrics.generated_kernel_count) if DO_PERF_TEST: from triton.testing import do_bench optf = torch.compile(f) print(f"ms={do_bench(lambda: optf(x))}") # Disable split reduction to make it easier to calculate the expected # number of bytes accessed. In this case, split reduction does not # help perf much. @inductor_config.patch(split_reductions=False) def test_fuse_reduction_with_tiled_pw(self): def f(x): y = torch.sum(torch.sum(x, dim=-1)) z = x / 10.0 z_t = z.t().contiguous().t() return y, z, z_t # use this input sizes to test for perf if DO_PERF_TEST: M, N = 1024 * 32, 1024 * 8 else: M, N = 200, 100 x = torch.randn(M, N, device=GPU_TYPE) actual = f(x) opt_f = torch.compile(f) expected = opt_f(x) self.assertTrue(same(actual, expected, tol=1e-3)) # We should fuse the first sum with the two pointwise. # Overall we read x once for all these three kernels and write # out 2 buffers with the same size as x. # This should be sort of 'optimal' for this workload. expected_numbytes = x.nbytes * 3 # A small amount of extra memory access for: # - store output for the first reduction # - load input for the second redution # - store output for the second reduction expected_numbytes += (M * 2 + 1) * x.itemsize print(expected_numbytes) self.assertEqual(expected_numbytes, metrics.num_bytes_accessed) if DO_PERF_TEST: from triton.testing import do_bench ms = do_bench(lambda: opt_f(x)) print(f"{ms=:.3f}") @inductor_config.patch( { "max_autotune": True, "max_autotune_gemm_backends": "TRITON", "test_configs.max_mm_configs": 4, } ) @skipUnless(HAS_GPU and is_big_gpu(), "Need big gpu for max-autotune") def test_interaction_with_triton_template(self): """ Make sure the dependency prefix for TritonTempalate and its prologue match. """ @torch.compile def f(x, y): return (x.expand([1, y.shape[0]]) + 1) @ y x = torch.randn([1, 1], device=GPU_TYPE) y = torch.randn([64, 128], device=GPU_TYPE) out, code = run_and_get_code(f, x, y) # well when benchmark_kernel flag is on, we have one more .run # call in the benchmarking code. FileCheck().check("def call(").check_count( ".run(", 1 + int(inductor_config.benchmark_kernel), exactly=True ).run(code[0]) @inductor_config.patch( { "max_autotune": True, "max_autotune_gemm_backends": "TRITON", "test_configs.max_mm_configs": 4, } ) @skipUnless(HAS_GPU and is_big_gpu(), "Need big gpu for max-autotune") def test_interaction_with_multi_template(self): """ Skip MultiTemplateBuffer during loop reordering """ @torch.compile def f(x, y): return (x @ y), x + 1 N = 2 x = torch.randn([N, N], device=GPU_TYPE, dtype=torch.bfloat16) y = torch.randn([N, N], device=GPU_TYPE, dtype=torch.bfloat16) out, code = run_and_get_code(f, x, y) # didn't fuse due to small savings FileCheck().check_count("@triton.jit", 2, exactly=True).run(code[0]) def test_fuse_with_scalar_shared_memory(self): """ Make sure if we can fuse two nodes sharing a scalar before, we can still do it with LOAF applied. This is not really a big deal. But some tests rely on this and less number of kernels has some small benefits. """ @torch.compile def f(x): return torch.mean(x) x = torch.randn([5, 5], device=GPU_TYPE) out, code = run_and_get_code(f, x) FileCheck().check_count("@triton.jit", 1, exactly=True).run(code[0]) def test_3dred_pw_2d_outer_red(self): """ Test a pattern as follows. We have a 3d contiguous tensor [m, n, k] as input. 1. do reduction on the k dimension and get a [m, n] tensor 2. 
        2. do a pointwise operation on this [m, n] tensor (and realize the
           computation)
        3. do an outer reduction on the output of step 2 on the m dimension.

        Each of these steps generates a kernel before fusion.
        Without any loop reordering, kernel 1 and kernel 2 will get fused, and
        kernel 3 will be separate. But if we reorder the loop for kernel 2,
        then kernel 2 will get fused with kernel 3, and the fused kernel-2-3
        cannot be fused with kernel 1.

        The older version of the LOAF algorithm would reorder in this case,
        but there is no real benefit. There are even some slight downsides:
        1. the original fusion without loop reordering is more natural
        2. fusing kernel 1 with kernel 2 may help precision when the output of
           kernel 1 is in low precision. By fusing kernel 1 and kernel 2, the
           pointwise operation will operate in fp32 precision thanks to fusion.
        """
        M, N, K = 64, 64, 64

        def f(x):
            x = x.sum(dim=-1)
            x = x + 1  # can be more complex like sigmoid or other ops
            return x, x.sum(dim=0)

        x = torch.randn(M, N, K, device=GPU_TYPE)
        self.do_acc_test(f, x)
        self.assertEqual(0, metrics.num_loop_reordering)


@inductor_config.patch(
    {
        "triton.unique_kernel_names": True,
        "loop_ordering_after_fusion": True,
        "triton.max_tiles": 3,
        "triton.coalesce_tiling_analysis": True,
    }
)
@instantiate_parametrized_tests
class MemoryCoalescingTest(MockSchedulerTest):
    """Tests for memory coalescing analysis with specific tensor sizes."""

    device = GPU_TYPE
    _exit_stack = None

    def setUp(self):
        super().setUp()
        metrics.reset()

    def _create_buffer(self, name, sizes):
        """Create a buffer with the specified sizes"""
        strides = ir.FlexibleLayout.contiguous_strides(sizes)
        box = ir.TensorBox.create(
            ir.Buffer(
                name=name,
                layout=ir.FixedLayout(
                    torch.device(self.device),
                    dtype=torch.float32,
                    size=sizes,
                    stride=strides,
                ),
            )
        )
        box_loader = box.make_loader()

        def inner_fn(index):
            return box_loader(index) * 2

        buf = ir.Pointwise.create(
            device=box.get_device(),
            dtype=box.get_dtype(),
            inner_fn=inner_fn,
            ranges=box.get_size(),
        )
        buf.realize()
        computed_buf = buf.data.data
        computed_buf.decide_layout()
        return computed_buf

    def _create_scheduler_node(self, buf):
        s = SchedulerNode(V.graph.scheduler, buf)
        s.min_order = 0
        s.max_order = 100
        return s

    @parametrize(
        "inps",
        (
            ((128, 384, 196), (768, 64, 196), (128, 6, 64, 196)),
            ((64,), (16, 4), (16, 4)),
            ((5, 6), (3, 10), (30,)),
            ((5, 6, 20), (3, 10, 20), (30, 20)),
        ),
    )
    def test_inferred_splits(self, inps):
        """
        Test memory coalescing analysis with the specified tensor sizes,
        using direct SchedulerNode creation with sizes such as (128, 384, 196)
        and (768, 64, 196).
""" s1, s2, expected_size = inps # Create buffers with the specified sizes buf1 = self._create_buffer("buffer1", s1) buf2 = self._create_buffer("buffer2", s2) # Create scheduler nodes snode1 = self._create_scheduler_node(buf1) snode2 = self._create_scheduler_node(buf2) # Create a fused node fused_node = torch._inductor.scheduler.FusedSchedulerNode.fuse(snode1, snode2) from torch._inductor import tiling_utils fused_norm_read_writes = tiling_utils.extract_normalized_read_writes(fused_node) var_ranges = fused_norm_read_writes.var_ranges self.assertEqual(list(var_ranges.values()), list(expected_size)) def test_remapped_reads(self): from torch._inductor import tiling_utils def fn(nodes): assert len(nodes) == 1 fused_norm_read_writes = tiling_utils.extract_normalized_read_writes( nodes[0] ) self.assertTrue(len(fused_norm_read_writes.var_ranges) == 2) # both reads remapped correctly FileCheck().check("4*n0 + n1").run( repr(fused_norm_read_writes.reads.keys()) ) FileCheck().check("n0 + 4*n1").run( repr(fused_norm_read_writes.reads.keys()) ) return nodes with torch._inductor.config.patch(_post_fusion_custom_pass=fn): @torch.compile() def foo(x, y): return x + y foo( torch.rand([4, 4], device=GPU_TYPE), torch.rand([4, 4], device=GPU_TYPE).T, ) def test_remapped_reads_split(self): from torch._inductor import tiling_utils def fn(nodes): self.assertTrue(len(nodes) == 1) fused_norm_read_writes = tiling_utils.extract_normalized_read_writes( nodes[0] ) inp_node_reads = nodes[0].get_nodes()[1]._body.get_read_exprs() node_ranges = nodes[0].get_nodes()[1]._body.var_ranges self.assertTrue(len(node_ranges) == 1) self.assertTrue(next(iter(node_ranges.values())) == 36) var = next(iter(node_ranges.keys())) r = FloorDiv(var, 6) + 6 * ModularIndexing(var, 1, 6) self.assertTrue(r in inp_node_reads) # mapped reads self.assertTrue(list(fused_norm_read_writes.var_ranges.values()) == [6, 6]) n0, n1 = list(fused_norm_read_writes.var_ranges.keys()) # translation of above is n0 + 6 * n1 self.assertTrue((n0 + 6 * n1) in fused_norm_read_writes.reads.keys()) return nodes with torch._inductor.config.patch(_post_fusion_custom_pass=fn): @torch.compile() def foo(x, y): return ( x + y ).contiguous().flatten() + torch.ops._inductor_test.realize( (y.T + 1).flatten() ) foo( torch.rand([6, 6], device=GPU_TYPE), torch.rand([6, 6], device=GPU_TYPE).T, ) def test_reduction_pointwise(self): # test one pw var, one red var from torch._inductor import tiling_utils def fn(nodes): self.assertTrue(len(nodes) == 1) fused_rw = tiling_utils.extract_normalized_read_writes(nodes[0]) i_vars, r_vars = fused_rw.index_vars, fused_rw.reduce_vars self.assertTrue(len(i_vars) == 1) self.assertTrue(len(r_vars) == 1) # single write to index var self.assertTrue( fused_rw.index_vars[0] == next(iter(fused_rw.writes.keys())) ) # the write to the fused intermediary node should be removed self.assertTrue(len(fused_rw.writes) == 1) # single read self.assertTrue(len(fused_rw.reads) == 1) # that is applied to two bufs self.assertTrue(len(next(iter(fused_rw.reads.values()))) == 2) # and the read should be in terms of the index + reduce var, # even though node is pointwise self.assertTrue(256 * i_vars[0] + r_vars[0] in fused_rw.reads) return nodes with torch._inductor.config.patch(_post_fusion_custom_pass=fn), torch.no_grad(): @torch.compile() def foo(x, y): out = torch.ops._inductor_test.realize(x + y) return out.sum(dim=1) foo( torch.rand(256, 256, device=GPU_TYPE), torch.rand(256, 256, device=GPU_TYPE), ) def test_reduction_no_pointwise(self): # test one pw var, one 
        from torch._inductor import tiling_utils

        def fn(nodes):
            self.assertTrue(len(nodes) == 1)
            fused_rw = tiling_utils.extract_normalized_read_writes(nodes[0])
            i_vars, r_vars = fused_rw.index_vars, fused_rw.reduce_vars
            self.assertTrue(len(i_vars) == 0)
            self.assertTrue(len(r_vars) == 1)
            return nodes

        with torch._inductor.config.patch(_post_fusion_custom_pass=fn), torch.no_grad():

            @torch.compile()
            def foo(x):
                return x.sum()

            foo(torch.rand(1024, device=GPU_TYPE))

    def test_coalescing(self):
        from torch._inductor import tiling_utils

        # Define symbolic variables
        i, j, n, m = sympy.symbols("i j n m", integer=True)

        # Test cases: (expression, var_ranges, expected_result)
        test_cases = [
            # Simple direct case
            (i + j * 5, {i: 10, j: 8}, i),
            # Floor division case
            (i + FloorDiv(j, 2), {i: 4, j: 8}, i),
            # Modular indexing
            (i * 10 + ModularIndexing(j, 1, 3), {i: 5, j: 10}, j),
            # Case with no coalescing variable
            (i * 2 + j * 3, {i: 8, j: 5}, None),
            # Division case
            (i / 2, {i: 10}, None),
            # More complex floor division
            (j + FloorDiv(i, 3), {i: 6, j: 12}, j),
            # Addition inside modular indexing
            (ModularIndexing(i + 3, 1, 6), {i: 8, j: 12}, i),
        ]

        for expr, var_ranges, expected in test_cases:
            # Test the function
            result = tiling_utils.find_coalesced_var(expr, var_ranges)
            self.assertEqual(result, expected)

    @parametrize("downcast_transposed_v", (False, True))
    def test_tiled_coalesce_analysis(self, downcast_transposed_v):
        # test one pw var, one red var
        from torch._inductor import tiling_utils

        def fn(nodes):
            self.assertTrue(len(nodes) == 1)
            coalesce_analysis = tiling_utils.analyze_memory_coalescing(nodes[0])
            i_vars = coalesce_analysis.norm_read_writes.index_vars

            # Because the output is contiguous, the second dimension should
            # coalesce twice as many bytes as the first dimension if not
            # downcasted. If downcasted, they should be equal, because of the
            # larger dtype size. We also weight writes x 2.
            cont_reads = coalesce_analysis.coalesced_by_var[i_vars[1]]
            t_reads = coalesce_analysis.coalesced_by_var[i_vars[0]]

            if not downcast_transposed_v:
                self.assertEqual(cont_reads, t_reads * 3)
            else:
                self.assertEqual(cont_reads, t_reads * 1.5)

            return nodes

        with torch._inductor.config.patch(_post_fusion_custom_pass=fn), torch.no_grad():

            @torch.compile()
            def foo(x, y):
                return x + y.to(x.dtype)

            y_dtype = torch.float if not downcast_transposed_v else torch.float64
            foo(
                torch.rand(256, 256, device=GPU_TYPE),
                torch.rand(256, 256, device=GPU_TYPE, dtype=y_dtype).T,
            )

    def test_solve_for_zero(self):
        from torch._inductor import tiling_utils

        x, y = sympy.symbols("x y", integer=True)

        # Test cases: (expression, expected_result)
        test_cases = [
            # Simple linear expressions
            (x + 5, (-5)),
            (2 * x - 10, (5)),
            # Constant expressions (should return None)
            (sympy.Integer(7), None),
            (sympy.Integer(0), None),
            # FloorDiv cases (should return None per function)
            (FloorDiv(x, 2), None),
            (FloorDiv(x, 2) + 5, None),
            # ModularIndexing cases
            (ModularIndexing(x, 1, 5), (5)),
            (ModularIndexing(x, 1, 3), (3)),
            # Expressions with no constant solution
            (x**2 + 1, None),  # No real solution
        ]

        for expr, expected in test_cases:
            result = tiling_utils.solve_for_zero(expr)
            self.assertEqual(result, expected)

    def test_solve_for_tiling(self):
        from torch._inductor import tiling_utils

        x = sympy.Symbol("x", integer=True)

        test_cases = [
            # Simple linear cases that coalesce
            (3 * x, None),
            # Expression with no free symbols
            # (sympy.Integer(5), None),
            (x / 3, 3),
            (FloorDiv(x * 2, 6), 3),
            # ModularIndexing expressions
            (ModularIndexing(FloorDiv(x, 4), 1, 64), 4),
            (x + ModularIndexing(x, 1, 5), None),
            (x**2, None),
            # Non-linear, diff is not constant
            (4096 * (ModularIndexing(32 * x, 1, 2048)) + FloorDiv(x, 64), 64),
            (4096 * (ModularIndexing(x, 1, 2048)) + FloorDiv(x, 2048), 2048),
        ]

        for expr, expected in test_cases:
            result = tiling_utils.solve_for_tiling(expr)
            self.assertEqual(result, expected)

    def test_induced_fused_tiling(self):
        from torch._inductor import tiling_utils

        def fn(nodes):
            self.assertTrue(len(nodes) == 1)
            coalesce_analysis = tiling_utils.analyze_memory_coalescing(nodes[0])
            self.assertEqual(coalesce_analysis.suggested_split.tiling_factor, 64)
            return nodes

        with torch._inductor.config.patch(_post_fusion_custom_pass=fn), torch.no_grad():

            def forward(permute):
                clone = torch.ops.aten.clone.default(
                    permute, memory_format=torch.contiguous_format
                )
                view_2 = torch.ops.aten.view.default(clone, [-1, 32])
                amax_1 = torch.ops.aten.amax.default(view_2, [1])
                return amax_1

            XDIM = 2048
            YDIM = 4096
            arg0_1 = torch.randn([XDIM, YDIM], device=GPU_TYPE, dtype=torch.bfloat16)
            permute = torch.ops.aten.permute.default(arg0_1, [1, 0])
            out, code = run_and_get_code(torch.compile(forward), (permute))

            self.assertEqual(out, forward(permute))
            FileCheck().check("YBLOCK").check("XBLOCK").run(code[0])


layouts = ("cont", "NHWC", "T")


@inductor_config.patch(
    {
        "triton.unique_kernel_names": True,
        "loop_ordering_after_fusion": True,
        "triton.coalesce_tiling_analysis": True,
    }
)
@instantiate_parametrized_tests
class TestTiling(TestCase):
    def T(self, layout: str):
        SIZE_A = 128
        SIZE_B = 256
        SIZE_C = 512

        if layout == "cont":
            return torch.rand(SIZE_A, SIZE_B, SIZE_C, device=GPU_TYPE).unsqueeze(0)
        elif layout == "T":
            return (
                torch.rand(SIZE_A, SIZE_B, SIZE_C, device=GPU_TYPE)
                .transpose(1, 2)
                .contiguous()
                .transpose(1, 2)
                .unsqueeze(0)
            )
        else:
            assert layout == "NHWC"
            return torch.rand([1, SIZE_A, SIZE_B, SIZE_C], device=GPU_TYPE).to(
                memory_format=torch.channels_last
            )

    @parametrize("a", layouts)
    @parametrize("b", layouts)
    def test_pointwise(self, a, b):
        def foo(x, y):
            return x + y

        x, y = self.T(a), self.T(b)
        res, code = run_and_get_code(torch.compile(foo), x, y)

        if a != b:
            FileCheck().check("ynumel").run(code[0])
        else:
            FileCheck().check_not("ynumel").run(code[0])

        self.assertEqual(res, foo(x, y))

    def test_tiled_reduction(self):
        def f(a, b):
            return (a * b).sum(dim=-1)

        N = 512
        inps = (
            torch.randn(N, N, N, device=GPU_TYPE).permute(2, 1, 0),
            torch.randn(N, N, N, device=GPU_TYPE).permute(1, 2, 0),
        )
        f_c = torch.compile(f)
        out, code = run_and_get_code(f_c, *inps)
        FileCheck().check_dag("xnumel = 512").check_dag("ynumel = 512").check_dag(
            "rnumel"
        ).run(code[0])
        self.assertEqual(out, f(*inps), atol=0.001, rtol=0.04)

    def test_3d_pointwise(self):
        inps = (self.T("cont"), self.T("T"), self.T("NHWC"))

        def f(x, y, z):
            return x + y + z

        f_c = torch.compile(f)
        out, code = run_and_get_code(f_c, *inps)
        FileCheck().check_dag("znumel").check_dag("ynumel").check_dag("xnumel").run(
            code[0]
        )
        self.assertEqual(out, f(*inps))

    def test_cat(self):
        # test unwrapping Identity
        def f(x, y):
            return torch.cat((x, y)) + 1

        x = self.T("cont")
        y = self.T("T")
        inps = (x, y)
        f_c = torch.compile(f)
        out, code = run_and_get_code(f_c, *inps)
        FileCheck().check_dag("ynumel").check_dag("xnumel").run(code[0])
        self.assertEqual(out, f(*inps))

    def test_penalized_small_dim(self):
        x = torch.rand([2000, 1], device=GPU_TYPE)
        y = torch.rand([4, 1], device=GPU_TYPE).T

        # don't tile when it doesn't affect total coalesced mem accesses much
        def f(x, y):
            return x + y

        inps = (x, y)
        f_c = torch.compile(f)
        out, code = run_and_get_code(f_c, *inps)
FileCheck().check_not("ynumel").check_dag("xnumel").run(code[0]) self.assertEqual(out, f(*inps)) def test_mutation_deps(self): def f(x): return x.add_(1) x = self.T("cont") from torch._inductor import tiling_utils def fn(nodes): self.assertTrue(len(nodes) == 1) coalesce_analysis = tiling_utils.analyze_memory_coalescing(nodes[0]) assert coalesce_analysis is not None reads = coalesce_analysis.norm_read_writes.reads writes = coalesce_analysis.norm_read_writes.writes self.assertTrue(len(reads) == 1 and len(writes) == 1) self.assertEqual( list(coalesce_analysis.norm_read_writes.reads.values()), [OrderedSet(("arg0_1",))], ) self.assertEqual( list(coalesce_analysis.norm_read_writes.writes.values()), [OrderedSet(("buf1",))], ) return nodes with torch._inductor.config.patch(_post_fusion_custom_pass=fn), torch.no_grad(): torch.compile(f)(x) if __name__ == "__main__": if HAS_GPU: run_tests()