Compare commits


6 Commits

Author SHA1 Message Date
86631eccda [Inductor] Remove stride-0 dimensions from more complex block pointers (#135557)
Related issue: #125077

### Feature
Inductor tries to remove dimensions with stride 0 from block pointers. Rather than loading with stride 0, it's more efficient to load a smaller block pointer, then use `tl.broadcast_to` to broadcast it up to the desired size. This already worked for simpler block pointers, but it was disabled for more complex block pointers which used `tl.reshape` to change the dimensionality after loading.

This PR generalizes the approach to work for all block pointers. The idea is to first reshape, adding singleton dimensions, then broadcast those singleton dimensions up to the full block size, and finally reshape again to the final output shape. For readability, we emit this code only when it actually does something; simpler loads keep a plain `tl.load`.
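
For context, kernels like the ones below come from compiling an ordinary broadcast add. Here is a minimal sketch mirroring the new test described in the Test Plan (the `add` helper and shapes are illustrative):
```python
import torch

# Minimal sketch: adding a full (8, 8) tensor to an (8, 1) column view forces a
# broadcast, which previously showed up as a stride-0 dimension in the block pointer.
def add(x, y):
    return x + y

device = torch.device("cuda")  # assumes a CUDA GPU with Triton is available
full = torch.randn(8, 8, device=device)
col = torch.as_strided(full, (8, 1), full.stride())

compiled = torch.compile(add)
out = compiled(full, col)  # Inductor emits a Triton kernel like the examples below
```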

Here's an example of a complicated kernel that does `load` -> `reshape` -> `broadcast` -> `reshape`. (The initial reshape comes from the slice `[:, None, None]`.)
```
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 64
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x1 = (xindex // 8)
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0])
    tmp1 = tl.reshape(tl.broadcast_to(tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[((7 + XBLOCK) // 8)], order=[0], offsets=[(xoffset // 8)]), boundary_check=[0], eviction_policy='evict_last')[:, None, None], [((7 + XBLOCK) // 8), ((1) * ((1) <= (((7 + XBLOCK) // 8))) + (((7 + XBLOCK) // 8)) * ((((7 + XBLOCK) // 8)) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))]), [XBLOCK])
    tmp2 = tmp0 + tmp1
    tl.store(tl.make_block_ptr(out_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tmp2.to(tl.float32), boundary_check=[0])
''', device_str='cuda')
```

Before this PR, we would have stride-0 dimensions:
```
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 64
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x1 = (xindex // 8)
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0])
    tmp1 = tl.reshape(tl.load(tl.make_block_ptr(in_ptr1, shape=[8, 1, 8], strides=[8, 0, 0], block_shape=[((7 + XBLOCK) // 8), ((1) * ((1) <= (((7 + XBLOCK) // 8))) + (((7 + XBLOCK) // 8)) * ((((7 + XBLOCK) // 8)) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))], order=[2, 1, 0], offsets=[(xoffset // 8), 0, xoffset % 8]), boundary_check=[0], eviction_policy='evict_last'), [XBLOCK])
    tmp2 = tmp0 + tmp1
    tl.store(tl.make_block_ptr(out_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.broadcast_to(tmp2, [XBLOCK]).to(tl.float32), boundary_check=[0])
''', device_str='cuda')
```

Here's a simpler example using 2D tiling. In this case we don't actually need an explicit broadcast: it is implied by the slice `[None, :]`, which adds a new singleton dimension. This code is not changed by this PR, but it's important to confirm that we don't accidentally insert unnecessary broadcasts.
```
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 8
    xnumel = 8
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    x1 = xindex
    y0 = yindex
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[8, 8], strides=[1, 8], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), boundary_check=[0, 1])
    tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[YBLOCK], order=[0], offsets=[yoffset]), boundary_check=[0], eviction_policy='evict_last')[None, :]
    tmp2 = tmp0 + tmp1
    tl.store(tl.make_block_ptr(out_ptr0, shape=[8, 8], strides=[1, 8], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), tmp2.to(tl.float32), boundary_check=[0, 1])
''', device_str='cuda')
```
### Test Plan
Added a new expecttest to check the emitted code for broadcast addition. Looking at the test, we can see that the stride-0 dimensions are removed. (This test generated the example kernels in the previous section.)

This change also removed a stride-0 dimension in an existing block pointer test. I updated the expected code accordingly.

Bonus: I noticed that the test parametrization for `config.prefer_nd_tiling` wasn't working as intended: the decorator received the single tuple `(False, True)`, which is truthy, so the option was effectively always `True`. Fixed it so we get the intended test coverage.
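
A minimal, framework-free illustration of why the old parametrization always enabled the option (variable names are illustrative):
```python
# Old parametrization: a single case whose value is the tuple (False, True).
buggy_values = [(False, True)]
# Fixed parametrization: two cases, one per boolean.
fixed_values = [False, True]

# A non-empty tuple is truthy, so the buggy version effectively always enabled
# prefer_nd_tiling; the fix covers both settings.
assert [bool(v) for v in buggy_values] == [True]
assert [bool(v) for v in fixed_values] == [False, True]
```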

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135557
Approved by: https://github.com/shunting314, https://github.com/jansel

Co-authored-by: Yueming Hao <yhao@meta.com>
2024-09-27 04:01:40 +00:00
2c5f5e303a [inductor] Triton codegen: Use scalar when creating f64 constant instead of 1-element tensor (#136594)
Summary: We have an internal report of a Triton compiler error `ValueError: Cannot broadcast, rank mismatch: [1], [1, 2048]` coming from a line like this:

`tmp25 = tl.broadcast_to(((tl.full([1], 1.00000000000000, tl.float64)) + ((ks0 // 3278).to(tl.float64))) / (((tl.full([1], 0.500000000000000, tl.float64))*(libdevice.sqrt((1 + ((ks0 // 3278)*(ks0 // 3278)) + ((-2)*(ks0 // 3278))).to(tl.float64).to(tl.float32)))) + ((tl.full([1], 0.500000000000000, tl.float64))*((1 + (ks0 // 3278)).to(tl.float64)))), [XBLOCK, RBLOCK])`

https://github.com/pytorch/pytorch/pull/135260 is the cause, presumably because it turns a constant into a 1-element tensor with `tl.full([1], const, tl.float64)`. Changing the syntax to `tl.full([], const, tl.float64)` yields a 0-d scalar instead, which appears to give us what we want.
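
For reference, a minimal Triton sketch of the distinction (illustrative, not part of the PR): `tl.full([1], c, tl.float64)` yields a rank-1 tensor, while `tl.full([], c, tl.float64)` yields a 0-d scalar that broadcasts against blocks of any rank.
```python
import torch
import triton
import triton.language as tl

@triton.jit
def _f64_const_sketch(out_ptr, XBLOCK: tl.constexpr):
    xindex = tl.arange(0, XBLOCK)
    # 0-d scalar constant: broadcasts against any block shape.
    c = tl.full([], 1.5, tl.float64)
    # A rank-1 constant, tl.full([1], 1.5, tl.float64), can instead trigger
    # "Cannot broadcast, rank mismatch" when combined with higher-rank blocks.
    val = (xindex.to(tl.float64) * c).to(tl.float32)
    tl.store(out_ptr + xindex, val)

out = torch.empty(16, device="cuda", dtype=torch.float32)
_f64_const_sketch[(1,)](out, XBLOCK=16)
```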

Differential Revision: [D63465169](https://our.internmc.facebook.com/intern/diff/D63465169)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136594
Approved by: https://github.com/mengluy0125, https://github.com/jansel
2024-09-27 04:01:09 +00:00
a2d2a30311 Add torch._dynamo.config.fail_on_cache_limit_hit (#136767)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136767
Approved by: https://github.com/albanD, https://github.com/jansel
ghstack dependencies: #136533
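
A minimal usage sketch of the new flag, mirroring the test added in this PR (eager backend, no GPU needed):
```python
import torch
import torch._dynamo.config
from torch._dynamo.exc import FailOnCacheLimitHit

# Raise a hard error instead of falling back when the cache size limit is hit.
torch._dynamo.config.cache_size_limit = 1
torch._dynamo.config.fail_on_cache_limit_hit = True

@torch.compile(backend="eager")
def func(b, a):
    if a:
        return b * 2
    else:
        return b + 1

func(torch.randn(5), True)        # first specialization fills the one-entry cache
try:
    func(torch.randn(5), False)   # recompiling on `a` exceeds the limit
except FailOnCacheLimitHit:
    print("cache size limit hit")
```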
2024-09-27 03:58:00 +00:00
2521cd3874 Skip kernel saving if already existed. (#136389)
Summary:
We skip `save_gpu_kernel` if the kernel has already been saved.
This gives us a more accurate Triton profiling result. The following traces show before/after the change when benchmarking a trivial addmm:

Before:
<img width="1255" alt="Screenshot 2024-09-23 at 10 26 53 AM" src="https://github.com/user-attachments/assets/5aea05ef-6ef0-464c-8da9-17b31c97b43a">

After:
<img width="910" alt="Screenshot 2024-09-23 at 10 27 03 AM" src="https://github.com/user-attachments/assets/488b7d4f-268f-41cf-8553-cb16ceeae118">

We can see that before the change, the benchmarked time includes two parts:
(1) The overhead of our triton_heuristic call, which includes the save/get and the (expensive) hash computation.
(2) The actual execution of the Triton kernel.

Part (1) accounts for more than 50% of the time, which often makes kernel selection during profiling choose aten kernels over Triton kernels.

Test Plan:
Existing OSS CI
[Redacted, Some internal model results in D63441430]

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136389
Approved by: https://github.com/desertfire
2024-09-27 03:03:28 +00:00
d1382aaf3d skip test_out_of_memory for jetson (#133270)
Skip test_out_of_memory in test/test_cuda.py on Jetson, as OOM reporting there has issues due to partially missing NVML support. cc @eqy
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133270
Approved by: https://github.com/eqy, https://github.com/albanD, https://github.com/seemethere
2024-09-27 02:36:48 +00:00
26869d38e1 [Inductor] Further solve missing aoti_torch_check symbole issue (#136775)
Summary: https://github.com/pytorch/pytorch/pull/136669 didn't resolve all of the internal test failures. Add more tests to OSS CI to catch the remaining issues, and fix some internal TARGETS dependencies.

Differential Revision: D63473744

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136775
Approved by: https://github.com/henrylhtsang
2024-09-27 02:26:49 +00:00
12 changed files with 271 additions and 85 deletions

View File

@ -375,9 +375,8 @@ test_inductor_cpp_wrapper_abi_compatible() {
mkdir -p "$TEST_REPORTS_DIR"
echo "Testing Inductor cpp wrapper mode with TORCHINDUCTOR_ABI_COMPATIBLE=1"
# cpu stack allocation causes segfault and needs more investigation
PYTORCH_TESTING_DEVICE_ONLY_FOR="" python test/run_test.py --include inductor/test_cpu_cpp_wrapper
python test/run_test.py --include inductor/test_cuda_cpp_wrapper
python test/run_test.py --include inductor/test_cuda_cpp_wrapper inductor/test_cpu_repro
TORCHINDUCTOR_CPP_WRAPPER=1 python benchmarks/dynamo/timm_models.py --device cuda --accuracy --amp \
--training --inductor --disable-cudagraphs --only vit_base_patch16_224 \

View File

@ -8,6 +8,7 @@ import torch._dynamo.config
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._logging
from torch._dynamo.exc import FailOnCacheLimitHit
from torch.testing._internal.logging_utils import kwargs_to_settings, log_settings
@ -203,6 +204,19 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase):
"expected type of 'L['b']' to be a tensor type, ' but found <class 'int'>",
)
@torch._dynamo.config.patch(cache_size_limit=1, fail_on_cache_limit_hit=True)
def test_fail_on_cache_limit_hit(self):
@torch.compile(backend="eager")
def func(b, a):
if a:
return b * 2
else:
return b + 1
func(torch.randn(5), True)
with self.assertRaises(FailOnCacheLimitHit):
func(torch.randn(5), False)
@torch._dynamo.config.patch("cache_size_limit", 32)
def test_multiple_guard_fails(self):
failure_reasons = []

View File

@ -12036,7 +12036,7 @@ if HAS_GPU and not TEST_WITH_ASAN:
self.assertExpectedInline(
"\n".join(lines),
"""\
tmp0 = tl.reshape(tl.load(block_ptr0, boundary_check=[3], padding_option='zero', eviction_policy='evict_last'), [XBLOCK, RBLOCK])
tmp0 = tl.reshape(tl.broadcast_to(tl.load(block_ptr0, boundary_check=[2], padding_option='zero', eviction_policy='evict_last')[:, None, :, :], [((511 + XBLOCK) // 512), ((1) * ((1) <= (((511 + XBLOCK) // 512))) + (((511 + XBLOCK) // 512)) * ((((511 + XBLOCK) // 512)) < (1))), ((512) * ((512) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (512))), RBLOCK]), [XBLOCK, RBLOCK])
tmp1 = tl.load(block_ptr1, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""", # noqa: B950 line too long
)

View File

@ -105,7 +105,7 @@ class TritonBlockPointerTest(InductorTestCase):
foo, *inputs, expected_num_block_pointers=expected_num_block_pointers
)
@parametrize("prefer_nd_tiling", [(False, True)])
@parametrize("prefer_nd_tiling", [False, True])
@parametrize(
"full_size,view_size,stride,offset,require_block_ptr",
[
@ -176,7 +176,7 @@ class TritonBlockPointerTest(InductorTestCase):
config_patches={"triton.prefer_nd_tiling": prefer_nd_tiling},
)
@parametrize("prefer_nd_tiling", [(False, True)])
@parametrize("prefer_nd_tiling", [False, True])
@parametrize(
"x_size,y_size",
[
@ -230,7 +230,59 @@ class TritonBlockPointerTest(InductorTestCase):
config_patches={"triton.prefer_nd_tiling": prefer_nd_tiling},
)
@parametrize("prefer_nd_tiling", [(False, True)])
@parametrize("prefer_nd_tiling", [False, True])
def test_pointwise_broadcast_nonzero_strides(self, prefer_nd_tiling: bool):
"""
Test that we emit tl.broadcast_to instead of using strides of 0.
"""
full_shape = (8, 8)
col_shape = (full_shape[1], 1)
device = torch.device(GPU_TYPE)
full = torch.randn(full_shape).to(device)
col = torch.as_strided(full, col_shape, full.stride())
# Expect 3 block pointers: 2 inputs one output
result, (triton_code,) = self.run_and_compare(
torch.add,
full,
col,
expected_num_block_pointers=3,
config_patches={
"triton.prefer_nd_tiling": prefer_nd_tiling,
},
)
# Check the code for broadcasts.
# We shouldn't see any strides of 0.
load_lines, store_lines = tuple(
[line for line in triton_code.split("\n") if substr in line]
for substr in ("tl.load", "tl.store")
)
if prefer_nd_tiling:
self.assertExpectedInline(
"\n".join(load_lines),
"""\
tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[8, 8], strides=[1, 8], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), boundary_check=[0, 1])
tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[YBLOCK], order=[0], offsets=[yoffset]), boundary_check=[0], eviction_policy='evict_last')[None, :]""", # noqa: B950
)
self.assertExpectedInline(
"\n".join(store_lines),
""" tl.store(tl.make_block_ptr(out_ptr0, shape=[8, 8], strides=[1, 8], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), tmp2.to(tl.float32), boundary_check=[0, 1])""", # noqa: B950
)
else:
self.assertExpectedInline(
"\n".join(load_lines),
"""\
tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0])
tmp1 = tl.reshape(tl.broadcast_to(tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[((7 + XBLOCK) // 8)], order=[0], offsets=[(xoffset // 8)]), boundary_check=[0], eviction_policy='evict_last')[:, None, None], [((7 + XBLOCK) // 8), ((1) * ((1) <= (((7 + XBLOCK) // 8))) + (((7 + XBLOCK) // 8)) * ((((7 + XBLOCK) // 8)) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))]), [XBLOCK])""", # noqa: B950
)
self.assertExpectedInline(
"\n".join(store_lines),
""" tl.store(tl.make_block_ptr(out_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tmp2.to(tl.float32), boundary_check=[0])""", # noqa: B950
)
@parametrize("prefer_nd_tiling", [False, True])
@parametrize(
"view_size,num_block_pointers,num_triton_kernels",
[

View File

@ -232,6 +232,9 @@ class TestCuda(TestCase):
device_properties_no_argument = torch.cuda.get_device_properties()
self.assertEqual(current_device_properties, device_properties_no_argument)
@unittest.skipIf(
IS_JETSON, "oom reporting has issues on jetson igx due to partial nvml support"
)
def test_out_of_memory(self):
tensor = torch.zeros(1024, device="cuda")

View File

@ -51,6 +51,12 @@ accumulated_cache_size_limit = 256
# [@compile_ignored: runtime_behaviour] skip tracing recursively if cache limit is hit
skip_code_recursive_on_cache_limit_hit = True
# raise a hard error if cache limit is hit. If you are on a model where you
# know you've sized the cache correctly, this can help detect problems when
# you regress guards/specialization. This works best when cache_size_limit = 1.
# [@compile_ignored: runtime_behaviour]
fail_on_cache_limit_hit = False
# whether or not to specialize on int inputs. This only has an effect with
# dynamic_shapes; when dynamic_shapes is False, we ALWAYS specialize on int
# inputs. Note that assume_static_by_default will also cause ints to get

View File

@ -70,6 +70,7 @@ from .exc import (
augment_exc_message,
BackendCompilerFailed,
CacheLimitExceeded,
FailOnCacheLimitHit,
format_error_msg,
InternalTorchDynamoError,
SkipCodeRecursiveException,
@ -858,7 +859,11 @@ def _compile(
format_guard_failures(),
troubleshooting_url,
)
if config.skip_code_recursive_on_cache_limit_hit and justknobs_check(
if config.fail_on_cache_limit_hit:
raise FailOnCacheLimitHit(
f"{limit_type} reached, because fail_on_cache_limit_hit = True this is a HARD failure"
)
elif config.skip_code_recursive_on_cache_limit_hit and justknobs_check(
"pytorch/compiler:skip_code_recursive_on_cache_limit_hit"
):
raise CacheLimitExceeded(f"{limit_type} reached")

View File

@ -188,6 +188,13 @@ class IncorrectUsage(Exception):
pass
# TODO: I'm a little uncertain about what error classification we should have
# for this. This is potentially a user error, but regressions in
# specialization in PyTorch proper could also trigger this problem
class FailOnCacheLimitHit(Exception):
pass
class ObservedException(TorchDynamoException):
# An exception observed during the tracing. This exception is used by Dynamo to handle exceptions.
pass

View File

@ -26,6 +26,7 @@
#include <c10/util/generic_math.h>
#include <c10/util/Half.h>
#include <c10/util/TypeCast.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX)
#define INDUCTOR_USE_VECTOR_TYPES() 1

View File

@ -17,6 +17,7 @@ from typing import (
Dict,
Iterable,
List,
no_type_check,
Optional,
Sequence,
Tuple,
@ -29,7 +30,12 @@ import torch
import torch._logging
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from torch.utils._sympy.symbol import (
free_symbol_is_type,
prefix_str,
symbol_is_type,
SymT,
)
from ..._dynamo.utils import counters
from .. import config, ir, scheduler
@ -41,6 +47,7 @@ from ..runtime.hints import ReductionHint
from ..runtime.runtime_utils import green_text, yellow_text
from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse
from ..utils import (
cache_on_self,
get_dtype_size,
IndentedBuffer,
Placeholder,
@ -106,6 +113,13 @@ class IterationRanges:
def symbol(self):
return sympy_index_symbol(self.name)
@property
@cache_on_self
@no_type_check
def symt(self) -> SymT:
prefix_to_symt = {prefix: symt for symt, prefix in prefix_str.items()}
return prefix_to_symt[self.prefix]
class IterationRangesRoot(IterationRanges):
def __init__(

View File

@ -67,6 +67,7 @@ from .common import (
)
from .simd import (
constant_repr,
IterationRanges,
IterationRangesEntry,
IterationRangesRoot,
pexpr,
@ -129,15 +130,34 @@ def gen_common_triton_imports():
return imports.getvalue()
block_offsets = {
symt: sympy.Symbol(f"{prefix_str[symt]}offset", integer=True, nonnegative=True)
for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.RINDEX]
}
class TritonSymbols:
"""
Stores sympy.Symbol instances and constants associated with triton codegen.
"""
block_sizes = {
symt: sympy.Symbol(f"{prefix_str[symt].upper()}BLOCK", integer=True, positive=True)
for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.RINDEX]
}
block_offsets = {
symt: sympy.Symbol(f"{prefix_str[symt]}offset", integer=True, nonnegative=True)
for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.RINDEX]
}
block_sizes = {
symt: sympy.Symbol(
f"{prefix_str[symt].upper()}BLOCK", integer=True, positive=True
)
for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.RINDEX]
}
@classmethod
def get_block_size(cls, tree: IterationRanges) -> sympy.Symbol:
return cls.block_sizes[tree.symt]
@classmethod
def get_block_offset(cls, tree: IterationRanges) -> sympy.Symbol:
return cls.block_offsets[tree.symt]
@classmethod
def max_block_size(cls, tree: IterationRanges) -> int:
return TRITON_MAX_BLOCK[tree.prefix.upper()]
@dataclasses.dataclass
@ -171,7 +191,9 @@ class BlockPtrOptions:
constant_offset: sympy.Expr
order: List[int]
mask_vars: OrderedSet[str]
reshape_suffix: List[str]
broadcast_shape: List[sympy.Expr]
broadcasting_dims: List[bool]
final_shape: List[sympy.Expr]
@property
def shape(self) -> List[sympy.Expr]:
@ -189,6 +211,50 @@ class BlockPtrOptions:
def offsets(self) -> List[sympy.Expr]:
return self.params.offsets
def codegen_broadcast_and_reshape(
self,
value: str,
initial_shape: List[sympy.Expr],
final_shape: List[sympy.Expr],
allow_implicit: bool,
) -> str:
"""
Generate a broadcast and a reshape for the block pointer.
This restores stride-0 dimensions which were removed from the block pointer.
"""
# Reshape to add singletons.
pre_broadcast_shape = [
sympy.Integer(1) if is_broadcasting else dim
for dim, is_broadcasting in zip(
self.broadcast_shape, self.broadcasting_dims
)
]
value = triton_reshape(value, initial_shape, pre_broadcast_shape)
# Broadcast singletons.
# For loads, we can often implicitly broadcast singleton dimensions.
# We need an explicit broadcast for stores, or if the final reshape does more
# than add singletons.
sizevars = V.graph.sizevars
if any(self.broadcasting_dims) and (
not allow_implicit
or len(pre_broadcast_shape) != len(final_shape)
or any(
not (
sizevars.statically_known_equals(pre_dim, 1)
or sizevars.statically_known_equals(pre_dim, post_dim)
)
for pre_dim, post_dim in zip(pre_broadcast_shape, final_shape)
)
):
value = f"tl.broadcast_to({value}, {V.kernel.index_to_str(self.broadcast_shape)})"
# Reshape to the final shape.
value = triton_reshape(value, self.broadcast_shape, final_shape)
return value
@staticmethod
def create(
*,
@ -198,21 +264,61 @@ class BlockPtrOptions:
mask_vars: OrderedSet[str],
) -> BlockPtrOptions:
"""Helper to create a BlockPtrOptions instance"""
reshape_suffix = [f"{t.prefix.upper()}BLOCK" for t in range_trees]
# Only drop broadcast dims if the output has the same
# rank as the block. Otherwise, we will get shape errors.
drop_broadcasts = len(reshape_suffix) == len(params.strides)
sizevars = V.graph.sizevars
broadcasting_dim = [s == 0 for s in params.strides]
for i, is_broadcasting in enumerate(broadcasting_dim):
if is_broadcasting and drop_broadcasts:
# drop any stride==0 dimensions for performance
reshape_suffix[i] = "1"
def lookup_size(exprs: Iterable[sympy.Expr]) -> List[sympy.Expr]:
return [sizevars.lookup_precomputed_size(expr) for expr in exprs]
# Look up precomputed sizes
params.shape = lookup_size(params.shape)
params.strides = lookup_size(params.strides)
# Strip out dimensions of stride 0.
# These will be restored with tl.broadcast_to.
broadcasting_dims = [
sizevars.statically_known_equals(stride, 0) for stride in params.strides
]
# Strip out dimensions of size 1.
# These will be restored by tl.reshape.
singleton_dims = [
sizevars.statically_known_equals(dim, 1) for dim in params.block_shape
]
if all(singleton_dims):
# Handle a pure singletons, e.g. [1, 1]
singleton_dims[-1] = False
# Record the post-broadcast shape before broadcasting dims are removed.
# The pre-broadcast shape is identical to this, except broadcasting dims are
# replaced with 1.
broadcast_shape = [
dim
for dim, is_singleton in zip(params.block_shape, singleton_dims)
if not is_singleton
]
# Combine all removable dims.
removable_dims = [any(dims) for dims in zip(singleton_dims, broadcasting_dims)]
def remove_dims(it):
"""Removes any broadcasting or singleton dims from a given sequence"""
return [
item
for item, is_removable in zip(it, removable_dims)
if not is_removable
]
# Drop removable dimensions from the input.
params = BlockParameters(
**{key: remove_dims(val) for key, val in dataclasses.asdict(params).items()}
)
# Compute the final shape, adjusting for special kernel types.
final_shape = [TritonSymbols.get_block_size(tree) for tree in range_trees]
if V.kernel.no_x_dim:
assert range_trees[0].prefix == "x"
reshape_suffix.pop(0)
final_shape.pop(0)
if (
not V.kernel.inside_reduction
@ -220,42 +326,23 @@ class BlockPtrOptions:
and V.kernel.numels[-1] != 1
):
# Need to expand rank by 1 to match rank when self.inside_reduction=True
reshape_suffix.append("1")
def filter(it):
"""Removes any broadcasting dims from a given sequence"""
assert len(it) == len(broadcasting_dim)
return [
item
for item, is_broadcasting in zip(it, broadcasting_dim)
if not is_broadcasting or not drop_broadcasts
]
# Drop broadcasting dimensions from the input.
params = BlockParameters(
**{key: filter(val) for key, val in dataclasses.asdict(params).items()}
)
def lookup_size(exprs: Iterable[sympy.Expr]) -> List[sympy.Expr]:
return [V.graph.sizevars.lookup_precomputed_size(expr) for expr in exprs]
# Look up precomputed sizes
params.shape = lookup_size(params.shape)
params.strides = lookup_size(params.strides)
final_shape.append(sympy.Integer(1))
return BlockPtrOptions(
params=params,
constant_offset=V.graph.sizevars.lookup_precomputed_size(constant_offset),
order=list(reversed(range(len(params.shape)))),
mask_vars=mask_vars,
reshape_suffix=reshape_suffix,
final_shape=final_shape,
broadcast_shape=broadcast_shape,
broadcasting_dims=broadcasting_dims,
)
def replace_roffset(self, expr: sympy.Expr, replacement: sympy.Expr) -> sympy.Expr:
"""
Replaces instances of roffset with the new expression.
"""
roffset = block_offsets[SymT.RINDEX]
roffset = TritonSymbols.block_offsets[SymT.RINDEX]
return sympy_subs(expr, {roffset: replacement})
def format(self, name: str, roffset=True) -> str:
@ -296,7 +383,7 @@ class BlockPtrOptions:
# This works in multiple_of checks because block sizes are powers of 2.
block_to_max: Dict[sympy.Expr, Any] = {
block_size: TRITON_MAX_BLOCK[prefix_str[symt].upper()]
for symt, block_size in block_sizes.items()
for symt, block_size in TritonSymbols.block_sizes.items()
}
return [
@ -314,7 +401,7 @@ class BlockPtrOptions:
)
and not (
V.kernel.no_x_dim
and self.block_shape[idx] == block_sizes[SymT.XBLOCK]
and self.block_shape[idx] == TritonSymbols.block_sizes[SymT.XBLOCK]
)
)
]
@ -328,7 +415,7 @@ class BlockPtrOptions:
Since we expect roffset to vary in range(0, rnumel, RBLOCK), the first
iteration has roffset=0, while the second has roffset=RBLOCK.
"""
rblock = block_sizes[SymT.RINDEX]
rblock = TritonSymbols.block_sizes[SymT.RINDEX]
advance = [
(
self.replace_roffset(offset, rblock)
@ -354,9 +441,19 @@ class BlockPtrOptions:
return bool(self.boundary_check())
def triton_reshape(value: str, old_shape: List[str], new_shape: List[str]):
def triton_reshape(
value: str, old_shape: List[sympy.Expr], new_shape: List[sympy.Expr]
):
"""Workaround https://github.com/openai/triton/issues/2836"""
assert isinstance(old_shape, list) and isinstance(new_shape, list)
def shape_to_str(shape: List[sympy.Expr]) -> List[str]:
return [str(dim) for dim in shape]
old_shape, new_shape = tuple(
shape_to_str(shape) for shape in (old_shape, new_shape)
)
if old_shape == new_shape:
return value
if [s for s in new_shape if s != "1"] != old_shape:
@ -387,12 +484,10 @@ class TritonPrinter(PythonPrinter):
)
def _print_Float(self, expr):
# Use a tensor here to get float64. Otherwise the constant is
# truncated to float32.
if config.is_fbcode() and torch.version.hip:
ret = f"{expr}"
else:
ret = f"tl.full([1], {expr}, tl.float64)"
ret = f"tl.full([], {expr}, tl.float64)"
return ret
def _print_ToFloat(self, expr):
@ -1236,19 +1331,6 @@ class TritonKernel(SIMDKernel):
self.codegen_range_tree()
def _get_symt(self, tree: IterationRangesEntry) -> SymT:
prefix_to_symt = {prefix: symt for symt, prefix in prefix_str.items()}
return prefix_to_symt[tree.prefix]
def _get_block_size(self, tree: IterationRangesEntry) -> sympy.Symbol:
return block_sizes[self._get_symt(tree)]
def _get_block_offset(self, tree: IterationRangesEntry) -> sympy.Symbol:
return block_offsets[self._get_symt(tree)]
def _max_block_size(self, tree: IterationRangesEntry) -> int:
return TRITON_MAX_BLOCK[tree.prefix.upper()]
def codegen_range_tree(self):
for tree in self.range_trees:
# reduction indexing goes inside a loop
@ -1395,9 +1477,9 @@ class TritonKernel(SIMDKernel):
return BlockParameters(
shape=[range_tree.numel],
block_shape=[self._get_block_size(range_tree)],
block_shape=[TritonSymbols.get_block_size(range_tree)],
strides=[m[stride]],
offsets=[self._get_block_offset(range_tree)],
offsets=[TritonSymbols.get_block_offset(range_tree)],
)
def match_mod_div_block(
@ -1508,7 +1590,7 @@ class TritonKernel(SIMDKernel):
# with n and m integers, then either numel is a multiple of XBLOCK, or numel
# is less than XBLOCK. (If numel is less than XBLOCK, we round up to 1 below.)
# 2. Numels are multiples of the maximum possible block size.
max_block = self._max_block_size(range_tree)
max_block = TritonSymbols.max_block_size(range_tree)
if any(
not sizevars.statically_known_multiple_of(numel, max_block)
and not sizevars.statically_known_power_of_2(numel)
@ -1524,7 +1606,7 @@ class TritonKernel(SIMDKernel):
# Non-leading dimensions are clamped to the size of the iteration range,
# while the leading dimension can exceed this to accomodate a larger
# block size.
linear_block_size = self._get_block_size(range_tree)
linear_block_size = TritonSymbols.get_block_size(range_tree)
block_shape: List[sympy.Expr] = [
CeilDiv(linear_block_size, slice_numels[0])
] + [
@ -1534,7 +1616,9 @@ class TritonKernel(SIMDKernel):
# Compute block offsets from {xyzr}offset and the matched expressions.
block_offsets: List[sympy.Expr] = [
sympy_subs(expr, {index_var: self._get_block_offset(range_tree)})
sympy_subs(
expr, {index_var: TritonSymbols.get_block_offset(range_tree)}
)
for expr in block_index_exprs
]
@ -1673,13 +1757,11 @@ class TritonKernel(SIMDKernel):
return block_ptr, advance_block_ptr, other
def codegen_block_ptr_store_line(self, name, indexing, block_ptr, value, other=""):
# broadcasting is not implicit for block_ptrs
value = (
f"tl.broadcast_to({value}, {self.index_to_str(indexing.reshape_suffix)})"
# Stores require an explicit broadcast.
value = indexing.codegen_broadcast_and_reshape(
value, indexing.final_shape, indexing.block_shape, False
)
# drop any extra size=1 dimensions
block_shape = [V.kernel.index_to_str(expr) for expr in indexing.block_shape]
value = triton_reshape(value, indexing.reshape_suffix, block_shape)
# workaround https://github.com/openai/triton/issues/2814
value = f"{value}.to({triton_store_type(V.graph.get_dtype(name))})"
return f"tl.store({block_ptr}, {value}{other})"
@ -1787,9 +1869,10 @@ class TritonKernel(SIMDKernel):
name, var, indexing, other
)
line = f"tl.load({block_ptr}{other}{ep})"
# add needed size=1 dimensions
block_shape = [str(dim) for dim in indexing.block_shape]
line = triton_reshape(line, block_shape, indexing.reshape_suffix)
line = indexing.codegen_broadcast_and_reshape(
line, indexing.block_shape, indexing.final_shape, True
)
elif isinstance(original_index, sympy.Integer):
line = f"tl.load({var} + ({original_index}))"
append_broadcast = indexing.expand_str

View File

@ -753,6 +753,8 @@ class CachingAutotuner(KernelInterface):
self.save_cache_hook(self.launchers[0].config, self.autotune_time_taken_ns)
def save_gpu_kernel(self, grid, stream, launcher):
if self.cuda_kernel_saved:
return
if callable(grid):
grid_x, grid_y, grid_z = grid(launcher.config.kwargs)
else: