[inductor] Fix 3d tiling (#141709)

Fixes #141121

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141709
Approved by: https://github.com/eellison
This commit is contained in:
Jason Ansel
2024-11-27 14:24:43 -08:00
committed by PyTorch MergeBot
parent ad3986498a
commit ca9bfa1a38
6 changed files with 39 additions and 6 deletions

View File

@ -1495,6 +1495,33 @@ def triton_poi_fused_add_reflection_pad2d_0(in_ptr0, in_ptr1, out_ptr0, xnumel,
device_stats2["active.all.peak"] <= device_stats["active.all.peak"]
)
@config.patch(
{
"triton.prefer_nd_tiling": True,
"triton.max_tiles": 3,
}
)
def test_3d_tiling(self):
full_size, view_size, num_block_pointers, num_tiles = (
(5, 5, 5, 5, 5),
(3, 3, 5, 3, 5),
1,
2,
)
GPU_TYPE = "cuda"
def get_input() -> torch.Tensor:
device = torch.device(GPU_TYPE)
full = torch.randn(full_size).to(device)
return torch.as_strided(full, view_size, full.stride())
a, b = get_input(), get_input()
opt_fn = torch.compile(functools.partial(torch.add))
result, (code,) = run_and_get_code(opt_fn, a, b)
self.assertEqual(result, a + b)
self.assertIn("znumel", code)
def test_repeated_masked_load(self):
target_size = (8, 2)
mem_eff_temporal_upsampling_interp_chunks = 2

View File

@ -153,14 +153,14 @@ class TritonSymbols:
block_offsets = {
symt: sympy.Symbol(f"{prefix_str[symt]}offset", integer=True, nonnegative=True)
for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.RINDEX]
for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.ZBLOCK, SymT.RINDEX]
}
block_sizes = {
symt: sympy.Symbol(
f"{prefix_str[symt].upper()}BLOCK", integer=True, positive=True
)
for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.RINDEX]
for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.ZBLOCK, SymT.RINDEX]
}
@classmethod
@ -1542,7 +1542,7 @@ class TritonKernel(SIMDKernel):
else:
# var is one of xN, yN or rN
assert symbol_is_type(
var, (SymT.RINDEX, SymT.XBLOCK, SymT.YBLOCK)
var, (SymT.RINDEX, SymT.XBLOCK, SymT.YBLOCK, SymT.ZBLOCK)
), var.name
mask_vars.add(f"{var.name[0]}mask")

View File

@ -960,6 +960,10 @@ class triton:
dense_indexing = False
# limit tiling dimensions
# - max_tiles=1 disables tiling
# - max_tiles=2 is the default
# - max_tiles=3 is experimental and may have bugs
# higher values are unsupported
max_tiles = 2
# Prefer higher dimensional tilings. This simplifies indexing expressions, making

View File

@ -4155,7 +4155,7 @@ class ComputedBuffer(OperationBuffer):
(iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
iter_ranges,
reduce_ranges,
prefix="z",
prefix="p",
)
body = LoopBody(
body,

View File

@ -215,7 +215,7 @@ class LoopBody:
# use the original symbol prefix
# Can try to optimize if this is a bottleneck for compilation time
(iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze(
iter_sizes, reduce_sizes, prefix="z"
iter_sizes, reduce_sizes, prefix="p"
)
new_body2 = LoopBody(
new_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2
@ -259,7 +259,7 @@ class LoopBody:
# use the original symbol prefix so we can do multiple round of reordering
(iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze(
*new_sizes, prefix="z" # type: ignore[arg-type]
*new_sizes, prefix="p" # type: ignore[arg-type]
)
new_body = LoopBody(
loop_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2

View File

@ -47,6 +47,7 @@ class SymT(Enum):
# Inductor: iteration domain for blockIdx.x/blockIdx.y
XBLOCK = auto()
YBLOCK = auto()
ZBLOCK = auto()
# Inductor: this is used solely for dynamic_reshape_indexer
VIEW = auto()
# Alternate (non-modular) indexing used in halide kernels
@ -70,6 +71,7 @@ prefix_str = {
SymT.TEMPLATE_INDEX: "idx",
SymT.XBLOCK: "x",
SymT.YBLOCK: "y",
SymT.ZBLOCK: "z",
SymT.INDIRECT: "indirect", # false aliasing?
SymT.VIEW: "view",
SymT.HALIDE: "h",