Compare commits

...

4 Commits

507c69e20f Don't uselessly recompute axiom dict every static eval call (#135429)
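
The gist of the change (a minimal sketch of the caching idea using simplified stand-ins, not the real `ShapeEnv` API; the actual hunks are in the `symbolic_shapes.py` diff below): guards and deferred runtime asserts now update a persistent `axioms` dict once, at creation time, and `_maybe_evaluate_static` reuses that dict instead of recomputing implications for every known guard on each call.

```
import sympy

class MiniShapeEnv:
    """Sketch only; the real change is in torch/fx/experimental/symbolic_shapes.py."""

    def __init__(self):
        self.guards = []
        # New in this commit: an incrementally maintained axiom dict.
        self.axioms = {}

    def get_implications(self, e):
        # Stand-in for ShapeEnv.get_implications: a guard implies itself.
        return [(e, sympy.true)]

    def add_guard(self, g):
        self.guards.append(g)
        # Update the cache once, when the guard is created...
        self.axioms.update(dict(self.get_implications(g)))

    def evaluate_static(self, expr):
        # ...so each static eval reuses the dict instead of rebuilding
        # it from all known guards on every call.
        return expr.xreplace(self.axioms)

s0 = sympy.Symbol("s0", integer=True)
env = MiniShapeEnv()
env.add_guard(sympy.Eq(sympy.Mod(s0, 3), 0))
print(env.evaluate_static(sympy.Eq(sympy.Mod(s0, 3), 0)))  # True
```
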
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135429
Approved by: https://github.com/isuruf
ghstack dependencies: #135137
2024-09-27 04:03:25 +00:00
285fa03b5e Deal with size oblivious before going into worker (#135137)
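
Roughly (a simplified sketch of the refactor; the real signatures are in the `symbolic_shapes.py` diff below), the size-oblivious clamping of value ranges now happens in the caller, before the `lru_cache`'d worker is invoked, so neither `size_oblivious` nor per-symbol size-like flags appear in the worker's cache key:

```
from collections import namedtuple
from functools import lru_cache

ValueRanges = namedtuple("ValueRanges", ["lower", "upper"])  # stand-in

@lru_cache(None)
def _maybe_evaluate_static_worker(expr, symbol_info, unbacked_only):
    # Pure function of its arguments: ranges arrive pre-adjusted,
    # so size_oblivious is gone from the signature (and the cache key).
    return None  # real simplification logic elided

def maybe_evaluate_static(expr, var_to_range, var_to_val, size_like,
                          unbacked_only=False, size_oblivious=False):
    def adjust_vr(k, vr):
        # Deal with size-oblivious reasoning *before* going into the
        # worker: clamp size-like symbols to [2, 2**48].
        if not (size_oblivious and k in size_like):
            return vr
        lower = max(2, vr.lower)
        upper = min(2**48, vr.upper)
        if lower > upper:  # degenerate bound; ignore the old upper bound
            upper = max(lower, 2**48)
        return ValueRanges(lower, upper)

    var_ranges = {k: adjust_vr(k, v) for k, v in var_to_range.items()}
    symbol_info = tuple(
        (s, var_ranges.get(s), var_to_val.get(s))
        for s in sorted(expr.free_symbols, key=str)
    )
    return _maybe_evaluate_static_worker(expr, symbol_info, unbacked_only)
```
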
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135137
Approved by: https://github.com/isuruf
2024-09-27 04:03:25 +00:00
86631eccda [Inductor] Remove stride-0 dimensions from more complex block pointers (#135557)
Related issue: #125077

### Feature
Inductor tries to remove dimensions with stride 0 from block pointers. Rather than loading with stride 0, it's more efficient to load a smaller block pointer, then use `tl.broadcast_to` to broadcast it up to the desired size. This already worked for simpler block pointers, but it was disabled for more complex block pointers which used `tl.reshape` to change the dimensionality after loading.

This PR generalizes the approach to work for all block pointers. The idea is to first reshape, adding singleton dimensions, then broadcast those singletons up to something larger, then reshape again to the final output shape. For readability, we emit this code only when it actually does something; simpler loads just emit a plain `tl.load`.
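
Schematically, the emitted load now follows this three-step pattern (illustrative only; `A`, `B` and `C` are placeholder sizes, and the real emitted code appears in the examples below):
```
tmp = tl.load(block_ptr)               # load with stride-0 dims already removed
tmp = tl.reshape(tmp, [A, 1, 1])       # 1) add singleton dimensions
tmp = tl.broadcast_to(tmp, [A, B, C])  # 2) broadcast the singletons up
tmp = tl.reshape(tmp, [XBLOCK])        # 3) reshape to the final output shape
```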

Here's an example of a complicated kernel that uses `load` -> `reshape` -> `broadcast` -> `reshape`. (The first reshape is actually the slice `[:, None, None]`.)
```
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 64
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x1 = (xindex // 8)
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0])
    tmp1 = tl.reshape(tl.broadcast_to(tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[((7 + XBLOCK) // 8)], order=[0], offsets=[(xoffset // 8)]), boundary_check=[0], eviction_policy='evict_last')[:, None, None], [((7 + XBLOCK) // 8), ((1) * ((1) <= (((7 + XBLOCK) // 8))) + (((7 + XBLOCK) // 8)) * ((((7 + XBLOCK) // 8)) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))]), [XBLOCK])
    tmp2 = tmp0 + tmp1
    tl.store(tl.make_block_ptr(out_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tmp2.to(tl.float32), boundary_check=[0])
''', device_str='cuda')
```

Before this PR, we would have stride-0 dimensions:
```
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 64
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x1 = (xindex // 8)
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0])
    tmp1 = tl.reshape(tl.load(tl.make_block_ptr(in_ptr1, shape=[8, 1, 8], strides=[8, 0, 0], block_shape=[((7 + XBLOCK) // 8), ((1) * ((1) <= (((7 + XBLOCK) // 8))) + (((7 + XBLOCK) // 8)) * ((((7 + XBLOCK) // 8)) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))], order=[2, 1, 0], offsets=[(xoffset // 8), 0, xoffset % 8]), boundary_check=[0], eviction_policy='evict_last'), [XBLOCK])
    tmp2 = tmp0 + tmp1
    tl.store(tl.make_block_ptr(out_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.broadcast_to(tmp2, [XBLOCK]).to(tl.float32), boundary_check=[0])
''', device_str='cuda')
```

Here's a simpler example using 2D tiling. In this case we don't actually need an explicit broadcast: it's implied by the slice `[None, :]`, which adds a new singleton dimension. This PR doesn't change this code, but it's important to check that we don't accidentally insert unnecessary broadcasts.
```
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 8
    xnumel = 8
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    x1 = xindex
    y0 = yindex
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[8, 8], strides=[1, 8], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), boundary_check=[0, 1])
    tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[YBLOCK], order=[0], offsets=[yoffset]), boundary_check=[0], eviction_policy='evict_last')[None, :]
    tmp2 = tmp0 + tmp1
    tl.store(tl.make_block_ptr(out_ptr0, shape=[8, 8], strides=[1, 8], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), tmp2.to(tl.float32), boundary_check=[0, 1])
''', device_str='cuda')
```
### Test Plan
Added a new expecttest to check the emitted code for broadcast addition. Looking at the test, we can see that stride-0 dimensions are removed. (This test generated the example kernels in the previous section.)

This change also removed a stride-0 dimension in an existing block pointer test. I updated the expected code accordingly.

Bonus: I noticed that the test parametrization for `config.prefer_nd_tiling` wasn't working as intended: the parameter list was a single tuple, `(False, True)`, which is truthy, so the option was effectively always `True`. Fixed it so we get the intended test coverage.
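
For context, a minimal sketch of the parametrization bug (illustrative; `parametrize` stands in for the helper from `torch.testing._internal.common_utils`):
```
# Buggy: one test case whose parameter is the single tuple (False, True).
# A non-empty tuple is truthy, so the option was effectively always True.
@parametrize("prefer_nd_tiling", [(False, True)])
def test_tiling_buggy(self, prefer_nd_tiling): ...

# Fixed: two test cases, one per boolean.
@parametrize("prefer_nd_tiling", [False, True])
def test_tiling_fixed(self, prefer_nd_tiling): ...
```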

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135557
Approved by: https://github.com/shunting314, https://github.com/jansel

Co-authored-by: Yueming Hao <yhao@meta.com>
2024-09-27 04:01:40 +00:00
2c5f5e303a [inductor] Triton codegen: Use scalar when creating f64 constant instead of 1-element tensor (#136594)
Summary: We have an internal report of a Triton compiler error `ValueError: Cannot broadcast, rank mismatch: [1], [1, 2048]` coming from a line like this:

```
tmp25 = tl.broadcast_to(((tl.full([1], 1.00000000000000, tl.float64)) + ((ks0 // 3278).to(tl.float64))) / (((tl.full([1], 0.500000000000000, tl.float64))*(libdevice.sqrt((1 + ((ks0 // 3278)*(ks0 // 3278)) + ((-2)*(ks0 // 3278))).to(tl.float64).to(tl.float32)))) + ((tl.full([1], 0.500000000000000, tl.float64))*((1 + (ks0 // 3278)).to(tl.float64)))), [XBLOCK, RBLOCK])
```

https://github.com/pytorch/pytorch/pull/135260 is the cause, presumably because we turn a constant into a 1-element tensor with `tl.full([1], const, tl.float64)`. It looks like changing the syntax to `tl.full([], const, tl.float64)`, which produces a scalar instead, gives us what we want.
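
A NumPy analogue of the rank distinction (an illustration only, not Triton itself; Triton rejects the rank-mismatched broadcast, per the error above):
```
import numpy as np

scalar = np.full([], 1.0, dtype=np.float64)  # rank 0, like tl.full([], ...)
rank1 = np.full([1], 1.0, dtype=np.float64)  # rank 1, like tl.full([1], ...)
block = np.zeros((1, 2048))                  # stands in for an [XBLOCK, RBLOCK] tile

print((scalar + block).shape)  # (1, 2048): a rank-0 scalar combines freely
print((rank1 + block).shape)   # also (1, 2048) in NumPy, but Triton's
# broadcast of a rank-1 [1] value against [1, 2048] fails with
# "ValueError: Cannot broadcast, rank mismatch: [1], [1, 2048]"
```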

Differential Revision: [D63465169](https://our.internmc.facebook.com/intern/diff/D63465169)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136594
Approved by: https://github.com/mengluy0125, https://github.com/jansel
2024-09-27 04:01:09 +00:00
6 changed files with 292 additions and 114 deletions

View File

@@ -10002,6 +10002,9 @@ ShapeEnv not equal: field values don't match:
"""\
ShapeEnv not equal: field values don't match:
==> axioms: values don't match.
> Left: {0 < Mod(s0, 3): False, Eq(0, Mod(s0, 3)): True, Eq(Mod(s0, 3), 0): True, False: False, Mod(s0, 3) <= 0: True, Ne(0, Mod(s0, 3)): False, Ne(Mod(s0, 3), 0): False, True: True}
> Right: {}
==> divisible: values don't match.
> Left: {Mod(s0, 3)}
> Right: {}
@@ -10039,6 +10042,9 @@ ShapeEnv not equal: field values don't match:
"""\
ShapeEnv not equal: field values don't match:
==> axioms: values don't match.
> Left: {False: False, True: True}
> Right: {}
==> guards: values don't match.
> Left: [Eq(s0, 3)]
> Right: []
@@ -10080,6 +10086,9 @@ ShapeEnv not equal: field values don't match:
"""\
ShapeEnv not equal: field values don't match:
==> axioms: values don't match.
> Left: {3 <= s0: True, s0 < 3: False}
> Right: {}
==> guards: values don't match.
> Left: [s0 >= 3]
> Right: []
@@ -10112,6 +10121,9 @@ ShapeEnv not equal: field values don't match:
"""\
ShapeEnv not equal: field values don't match:
==> axioms: values don't match.
> Left: {0 < PythonMod(u0, 3): False, Eq(0, PythonMod(u0, 3)): True, Eq(PythonMod(u0, 3), 0): True, False: False, Ne(0, PythonMod(u0, 3)): False, Ne(PythonMod(u0, 3), 0): False, PythonMod(u0, 3) <= 0: True, True: True}
> Right: {}
==> deferred_runtime_asserts: values don't match.
> Left: {u0: [Eq(PythonMod(u0, 3), 0)]}
> Right: {}

View File

@@ -12036,7 +12036,7 @@ if HAS_GPU and not TEST_WITH_ASAN:
self.assertExpectedInline(
"\n".join(lines),
"""\
tmp0 = tl.reshape(tl.load(block_ptr0, boundary_check=[3], padding_option='zero', eviction_policy='evict_last'), [XBLOCK, RBLOCK])
tmp0 = tl.reshape(tl.broadcast_to(tl.load(block_ptr0, boundary_check=[2], padding_option='zero', eviction_policy='evict_last')[:, None, :, :], [((511 + XBLOCK) // 512), ((1) * ((1) <= (((511 + XBLOCK) // 512))) + (((511 + XBLOCK) // 512)) * ((((511 + XBLOCK) // 512)) < (1))), ((512) * ((512) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (512))), RBLOCK]), [XBLOCK, RBLOCK])
tmp1 = tl.load(block_ptr1, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""", # noqa: B950 line too long
)

View File

@@ -105,7 +105,7 @@ class TritonBlockPointerTest(InductorTestCase):
foo, *inputs, expected_num_block_pointers=expected_num_block_pointers
)
@parametrize("prefer_nd_tiling", [(False, True)])
@parametrize("prefer_nd_tiling", [False, True])
@parametrize(
"full_size,view_size,stride,offset,require_block_ptr",
[
@@ -176,7 +176,7 @@ class TritonBlockPointerTest(InductorTestCase):
config_patches={"triton.prefer_nd_tiling": prefer_nd_tiling},
)
@parametrize("prefer_nd_tiling", [(False, True)])
@parametrize("prefer_nd_tiling", [False, True])
@parametrize(
"x_size,y_size",
[
@@ -230,7 +230,59 @@ class TritonBlockPointerTest(InductorTestCase):
config_patches={"triton.prefer_nd_tiling": prefer_nd_tiling},
)
@parametrize("prefer_nd_tiling", [(False, True)])
@parametrize("prefer_nd_tiling", [False, True])
def test_pointwise_broadcast_nonzero_strides(self, prefer_nd_tiling: bool):
"""
Test that we emit tl.broadcast_to instead of using strides of 0.
"""
full_shape = (8, 8)
col_shape = (full_shape[1], 1)
device = torch.device(GPU_TYPE)
full = torch.randn(full_shape).to(device)
col = torch.as_strided(full, col_shape, full.stride())
    # Expect 3 block pointers: 2 inputs, 1 output
result, (triton_code,) = self.run_and_compare(
torch.add,
full,
col,
expected_num_block_pointers=3,
config_patches={
"triton.prefer_nd_tiling": prefer_nd_tiling,
},
)
# Check the code for broadcasts.
# We shouldn't see any strides of 0.
load_lines, store_lines = tuple(
[line for line in triton_code.split("\n") if substr in line]
for substr in ("tl.load", "tl.store")
)
if prefer_nd_tiling:
self.assertExpectedInline(
"\n".join(load_lines),
"""\
tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[8, 8], strides=[1, 8], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), boundary_check=[0, 1])
tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[YBLOCK], order=[0], offsets=[yoffset]), boundary_check=[0], eviction_policy='evict_last')[None, :]""", # noqa: B950
)
self.assertExpectedInline(
"\n".join(store_lines),
""" tl.store(tl.make_block_ptr(out_ptr0, shape=[8, 8], strides=[1, 8], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), tmp2.to(tl.float32), boundary_check=[0, 1])""", # noqa: B950
)
else:
self.assertExpectedInline(
"\n".join(load_lines),
"""\
tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0])
tmp1 = tl.reshape(tl.broadcast_to(tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[((7 + XBLOCK) // 8)], order=[0], offsets=[(xoffset // 8)]), boundary_check=[0], eviction_policy='evict_last')[:, None, None], [((7 + XBLOCK) // 8), ((1) * ((1) <= (((7 + XBLOCK) // 8))) + (((7 + XBLOCK) // 8)) * ((((7 + XBLOCK) // 8)) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))]), [XBLOCK])""", # noqa: B950
)
self.assertExpectedInline(
"\n".join(store_lines),
""" tl.store(tl.make_block_ptr(out_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tmp2.to(tl.float32), boundary_check=[0])""", # noqa: B950
)
@parametrize("prefer_nd_tiling", [False, True])
@parametrize(
"view_size,num_block_pointers,num_triton_kernels",
[

View File

@@ -17,6 +17,7 @@ from typing import (
Dict,
Iterable,
List,
no_type_check,
Optional,
Sequence,
Tuple,
@@ -29,7 +30,12 @@ import torch
import torch._logging
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from torch.utils._sympy.symbol import (
free_symbol_is_type,
prefix_str,
symbol_is_type,
SymT,
)
from ..._dynamo.utils import counters
from .. import config, ir, scheduler
@@ -41,6 +47,7 @@ from ..runtime.hints import ReductionHint
from ..runtime.runtime_utils import green_text, yellow_text
from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse
from ..utils import (
cache_on_self,
get_dtype_size,
IndentedBuffer,
Placeholder,
@@ -106,6 +113,13 @@ class IterationRanges:
def symbol(self):
return sympy_index_symbol(self.name)
@property
@cache_on_self
@no_type_check
def symt(self) -> SymT:
prefix_to_symt = {prefix: symt for symt, prefix in prefix_str.items()}
return prefix_to_symt[self.prefix]
class IterationRangesRoot(IterationRanges):
def __init__(

View File

@@ -67,6 +67,7 @@ from .common import (
)
from .simd import (
constant_repr,
IterationRanges,
IterationRangesEntry,
IterationRangesRoot,
pexpr,
@@ -129,16 +130,35 @@ def gen_common_triton_imports():
return imports.getvalue()
class TritonSymbols:
"""
Stores sympy.Symbol instances and constants associated with triton codegen.
"""
block_offsets = {
symt: sympy.Symbol(f"{prefix_str[symt]}offset", integer=True, nonnegative=True)
for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.RINDEX]
}
block_sizes = {
symt: sympy.Symbol(f"{prefix_str[symt].upper()}BLOCK", integer=True, positive=True)
symt: sympy.Symbol(
f"{prefix_str[symt].upper()}BLOCK", integer=True, positive=True
)
for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.RINDEX]
}
@classmethod
def get_block_size(cls, tree: IterationRanges) -> sympy.Symbol:
return cls.block_sizes[tree.symt]
@classmethod
def get_block_offset(cls, tree: IterationRanges) -> sympy.Symbol:
return cls.block_offsets[tree.symt]
@classmethod
def max_block_size(cls, tree: IterationRanges) -> int:
return TRITON_MAX_BLOCK[tree.prefix.upper()]
@dataclasses.dataclass
class IndexingOptions:
@@ -171,7 +191,9 @@ class BlockPtrOptions:
constant_offset: sympy.Expr
order: List[int]
mask_vars: OrderedSet[str]
reshape_suffix: List[str]
broadcast_shape: List[sympy.Expr]
broadcasting_dims: List[bool]
final_shape: List[sympy.Expr]
@property
def shape(self) -> List[sympy.Expr]:
@@ -189,6 +211,50 @@ class BlockPtrOptions:
def offsets(self) -> List[sympy.Expr]:
return self.params.offsets
def codegen_broadcast_and_reshape(
self,
value: str,
initial_shape: List[sympy.Expr],
final_shape: List[sympy.Expr],
allow_implicit: bool,
) -> str:
"""
Generate a broadcast and a reshape for the block pointer.
This restores stride-0 dimensions which were removed from the block pointer.
"""
# Reshape to add singletons.
pre_broadcast_shape = [
sympy.Integer(1) if is_broadcasting else dim
for dim, is_broadcasting in zip(
self.broadcast_shape, self.broadcasting_dims
)
]
value = triton_reshape(value, initial_shape, pre_broadcast_shape)
# Broadcast singletons.
# For loads, we can often implicitly broadcast singleton dimensions.
# We need an explicit broadcast for stores, or if the final reshape does more
# than add singletons.
sizevars = V.graph.sizevars
if any(self.broadcasting_dims) and (
not allow_implicit
or len(pre_broadcast_shape) != len(final_shape)
or any(
not (
sizevars.statically_known_equals(pre_dim, 1)
or sizevars.statically_known_equals(pre_dim, post_dim)
)
for pre_dim, post_dim in zip(pre_broadcast_shape, final_shape)
)
):
value = f"tl.broadcast_to({value}, {V.kernel.index_to_str(self.broadcast_shape)})"
# Reshape to the final shape.
value = triton_reshape(value, self.broadcast_shape, final_shape)
return value
@staticmethod
def create(
*,
@@ -198,21 +264,61 @@ class BlockPtrOptions:
mask_vars: OrderedSet[str],
) -> BlockPtrOptions:
"""Helper to create a BlockPtrOptions instance"""
reshape_suffix = [f"{t.prefix.upper()}BLOCK" for t in range_trees]
# Only drop broadcast dims if the output has the same
# rank as the block. Otherwise, we will get shape errors.
drop_broadcasts = len(reshape_suffix) == len(params.strides)
sizevars = V.graph.sizevars
broadcasting_dim = [s == 0 for s in params.strides]
for i, is_broadcasting in enumerate(broadcasting_dim):
if is_broadcasting and drop_broadcasts:
# drop any stride==0 dimensions for performance
reshape_suffix[i] = "1"
def lookup_size(exprs: Iterable[sympy.Expr]) -> List[sympy.Expr]:
return [sizevars.lookup_precomputed_size(expr) for expr in exprs]
# Look up precomputed sizes
params.shape = lookup_size(params.shape)
params.strides = lookup_size(params.strides)
# Strip out dimensions of stride 0.
# These will be restored with tl.broadcast_to.
broadcasting_dims = [
sizevars.statically_known_equals(stride, 0) for stride in params.strides
]
# Strip out dimensions of size 1.
# These will be restored by tl.reshape.
singleton_dims = [
sizevars.statically_known_equals(dim, 1) for dim in params.block_shape
]
if all(singleton_dims):
        # Handle pure singletons, e.g. [1, 1]
singleton_dims[-1] = False
# Record the post-broadcast shape before broadcasting dims are removed.
# The pre-broadcast shape is identical to this, except broadcasting dims are
# replaced with 1.
broadcast_shape = [
dim
for dim, is_singleton in zip(params.block_shape, singleton_dims)
if not is_singleton
]
# Combine all removable dims.
removable_dims = [any(dims) for dims in zip(singleton_dims, broadcasting_dims)]
def remove_dims(it):
"""Removes any broadcasting or singleton dims from a given sequence"""
return [
item
for item, is_removable in zip(it, removable_dims)
if not is_removable
]
# Drop removable dimensions from the input.
params = BlockParameters(
**{key: remove_dims(val) for key, val in dataclasses.asdict(params).items()}
)
# Compute the final shape, adjusting for special kernel types.
final_shape = [TritonSymbols.get_block_size(tree) for tree in range_trees]
if V.kernel.no_x_dim:
assert range_trees[0].prefix == "x"
reshape_suffix.pop(0)
final_shape.pop(0)
if (
not V.kernel.inside_reduction
@@ -220,42 +326,23 @@ class BlockPtrOptions:
and V.kernel.numels[-1] != 1
):
# Need to expand rank by 1 to match rank when self.inside_reduction=True
reshape_suffix.append("1")
def filter(it):
"""Removes any broadcasting dims from a given sequence"""
assert len(it) == len(broadcasting_dim)
return [
item
for item, is_broadcasting in zip(it, broadcasting_dim)
if not is_broadcasting or not drop_broadcasts
]
# Drop broadcasting dimensions from the input.
params = BlockParameters(
**{key: filter(val) for key, val in dataclasses.asdict(params).items()}
)
def lookup_size(exprs: Iterable[sympy.Expr]) -> List[sympy.Expr]:
return [V.graph.sizevars.lookup_precomputed_size(expr) for expr in exprs]
# Look up precomputed sizes
params.shape = lookup_size(params.shape)
params.strides = lookup_size(params.strides)
final_shape.append(sympy.Integer(1))
return BlockPtrOptions(
params=params,
constant_offset=V.graph.sizevars.lookup_precomputed_size(constant_offset),
order=list(reversed(range(len(params.shape)))),
mask_vars=mask_vars,
reshape_suffix=reshape_suffix,
final_shape=final_shape,
broadcast_shape=broadcast_shape,
broadcasting_dims=broadcasting_dims,
)
def replace_roffset(self, expr: sympy.Expr, replacement: sympy.Expr) -> sympy.Expr:
"""
Replaces instances of roffset with the new expression.
"""
roffset = block_offsets[SymT.RINDEX]
roffset = TritonSymbols.block_offsets[SymT.RINDEX]
return sympy_subs(expr, {roffset: replacement})
def format(self, name: str, roffset=True) -> str:
@@ -296,7 +383,7 @@ class BlockPtrOptions:
# This works in multiple_of checks because block sizes are powers of 2.
block_to_max: Dict[sympy.Expr, Any] = {
block_size: TRITON_MAX_BLOCK[prefix_str[symt].upper()]
for symt, block_size in block_sizes.items()
for symt, block_size in TritonSymbols.block_sizes.items()
}
return [
@@ -314,7 +401,7 @@ class BlockPtrOptions:
)
and not (
V.kernel.no_x_dim
and self.block_shape[idx] == block_sizes[SymT.XBLOCK]
and self.block_shape[idx] == TritonSymbols.block_sizes[SymT.XBLOCK]
)
)
]
@@ -328,7 +415,7 @@ class BlockPtrOptions:
Since we expect roffset to vary in range(0, rnumel, RBLOCK), the first
iteration has roffset=0, while the second has roffset=RBLOCK.
"""
rblock = block_sizes[SymT.RINDEX]
rblock = TritonSymbols.block_sizes[SymT.RINDEX]
advance = [
(
self.replace_roffset(offset, rblock)
@@ -354,9 +441,19 @@ class BlockPtrOptions:
return bool(self.boundary_check())
def triton_reshape(value: str, old_shape: List[str], new_shape: List[str]):
def triton_reshape(
value: str, old_shape: List[sympy.Expr], new_shape: List[sympy.Expr]
):
"""Workaround https://github.com/openai/triton/issues/2836"""
assert isinstance(old_shape, list) and isinstance(new_shape, list)
def shape_to_str(shape: List[sympy.Expr]) -> List[str]:
return [str(dim) for dim in shape]
old_shape, new_shape = tuple(
shape_to_str(shape) for shape in (old_shape, new_shape)
)
if old_shape == new_shape:
return value
if [s for s in new_shape if s != "1"] != old_shape:
@@ -387,12 +484,10 @@ class TritonPrinter(PythonPrinter):
)
def _print_Float(self, expr):
# Use a tensor here to get float64. Otherwise the constant is
# truncated to float32.
if config.is_fbcode() and torch.version.hip:
ret = f"{expr}"
else:
ret = f"tl.full([1], {expr}, tl.float64)"
ret = f"tl.full([], {expr}, tl.float64)"
return ret
def _print_ToFloat(self, expr):
@@ -1236,19 +1331,6 @@ class TritonKernel(SIMDKernel):
self.codegen_range_tree()
def _get_symt(self, tree: IterationRangesEntry) -> SymT:
prefix_to_symt = {prefix: symt for symt, prefix in prefix_str.items()}
return prefix_to_symt[tree.prefix]
def _get_block_size(self, tree: IterationRangesEntry) -> sympy.Symbol:
return block_sizes[self._get_symt(tree)]
def _get_block_offset(self, tree: IterationRangesEntry) -> sympy.Symbol:
return block_offsets[self._get_symt(tree)]
def _max_block_size(self, tree: IterationRangesEntry) -> int:
return TRITON_MAX_BLOCK[tree.prefix.upper()]
def codegen_range_tree(self):
for tree in self.range_trees:
# reduction indexing goes inside a loop
@@ -1395,9 +1477,9 @@ class TritonKernel(SIMDKernel):
return BlockParameters(
shape=[range_tree.numel],
block_shape=[self._get_block_size(range_tree)],
block_shape=[TritonSymbols.get_block_size(range_tree)],
strides=[m[stride]],
offsets=[self._get_block_offset(range_tree)],
offsets=[TritonSymbols.get_block_offset(range_tree)],
)
def match_mod_div_block(
@@ -1508,7 +1590,7 @@ class TritonKernel(SIMDKernel):
# with n and m integers, then either numel is a multiple of XBLOCK, or numel
# is less than XBLOCK. (If numel is less than XBLOCK, we round up to 1 below.)
# 2. Numels are multiples of the maximum possible block size.
max_block = self._max_block_size(range_tree)
max_block = TritonSymbols.max_block_size(range_tree)
if any(
not sizevars.statically_known_multiple_of(numel, max_block)
and not sizevars.statically_known_power_of_2(numel)
@@ -1524,7 +1606,7 @@ class TritonKernel(SIMDKernel):
# Non-leading dimensions are clamped to the size of the iteration range,
    # while the leading dimension can exceed this to accommodate a larger
# block size.
linear_block_size = self._get_block_size(range_tree)
linear_block_size = TritonSymbols.get_block_size(range_tree)
block_shape: List[sympy.Expr] = [
CeilDiv(linear_block_size, slice_numels[0])
] + [
@@ -1534,7 +1616,9 @@ class TritonKernel(SIMDKernel):
# Compute block offsets from {xyzr}offset and the matched expressions.
block_offsets: List[sympy.Expr] = [
sympy_subs(expr, {index_var: self._get_block_offset(range_tree)})
sympy_subs(
expr, {index_var: TritonSymbols.get_block_offset(range_tree)}
)
for expr in block_index_exprs
]
@@ -1673,13 +1757,11 @@ class TritonKernel(SIMDKernel):
return block_ptr, advance_block_ptr, other
def codegen_block_ptr_store_line(self, name, indexing, block_ptr, value, other=""):
# broadcasting is not implicit for block_ptrs
value = (
f"tl.broadcast_to({value}, {self.index_to_str(indexing.reshape_suffix)})"
# Stores require an explicit broadcast.
value = indexing.codegen_broadcast_and_reshape(
value, indexing.final_shape, indexing.block_shape, False
)
# drop any extra size=1 dimensions
block_shape = [V.kernel.index_to_str(expr) for expr in indexing.block_shape]
value = triton_reshape(value, indexing.reshape_suffix, block_shape)
# workaround https://github.com/openai/triton/issues/2814
value = f"{value}.to({triton_store_type(V.graph.get_dtype(name))})"
return f"tl.store({block_ptr}, {value}{other})"
@@ -1787,9 +1869,10 @@ class TritonKernel(SIMDKernel):
name, var, indexing, other
)
line = f"tl.load({block_ptr}{other}{ep})"
# add needed size=1 dimensions
block_shape = [str(dim) for dim in indexing.block_shape]
line = triton_reshape(line, block_shape, indexing.reshape_suffix)
line = indexing.codegen_broadcast_and_reshape(
line, indexing.block_shape, indexing.final_shape, True
)
elif isinstance(original_index, sympy.Integer):
line = f"tl.load({var} + ({original_index}))"
append_broadcast = indexing.expand_str

View File

@@ -1516,9 +1516,8 @@ def safe_expand(r):
@lru_cache(None)
def _maybe_evaluate_static_worker(
expr: sympy.Expr,
symbol_info: Tuple[Tuple[sympy.Symbol, ValueRanges, sympy.Integer, bool], ...],
symbol_info: Tuple[Tuple[sympy.Symbol, ValueRanges, sympy.Integer], ...],
unbacked_only: bool,
size_oblivious: bool
):
"""
This variant of ShapeEnv._maybe_evaluate_static has no dependence on
@@ -1531,33 +1530,20 @@ def _maybe_evaluate_static_worker(
new_shape_env = {}
new_range_env = {}
for idx, sinfo in enumerate(symbol_info):
k, vr, val, is_size_like = sinfo
if isinstance(val, SingletonInt):
k, vr, hint = sinfo
if isinstance(hint, SingletonInt):
# Skip var_ranges logic for SingletonInt which is only used
# for jagged layout NestedTensors today
continue
if size_oblivious and is_size_like:
lower = max(2, vr.lower)
# Clamping size-oblivious to some quantity below sys.maxsize
# helps us determine that f(u0) != sys.maxsize, which is a
# test that is looking for sys.maxsize as a sentinel, but you
# don't really want to worry about it for unbacked SymInts.
# This is similar to the flavor where size oblivious omits
# 0/1, it changes semantics but in a benign way.
upper = min(2 ** 48, vr.upper)
# This is a bit dodgy: what this means is that there was a
# size-like unbacked symbol whose upper bound < 2. This
# causes... problems.
if lower <= upper:
vr = ValueRanges(lower, upper)
else:
lower = vr.lower
# Don't do anything if we don't have a nontrivial lower bound
# Also don't do anything if we asked only to simplify unbacked
# SymInt
if (
lower is -int_oo or
(unbacked_only and val is not None) or
(unbacked_only and hint is not None) or
not vr.is_int
):
new_range_env[k] = vr
@@ -2637,6 +2623,7 @@ class ShapeEnv:
)
self.guards: List[ShapeGuard] = []
self.axioms: Dict[sympy.Expr, sympy.Expr] = {}
# Maps symbolic ints to their original concrete values
# Currently populated from tensors
self.var_to_val: Dict[sympy.Symbol, sympy.Integer] = {}
@@ -4658,33 +4645,61 @@ class ShapeEnv:
expr = canonicalize_bool_expr(expr)
# Pattern matching
symbols = tuple(expr.free_symbols)
if axioms is None:
axioms = self.get_axioms(symbols, compute_hint=compute_hint)
subst = self.axioms
else:
subst = {}
for e in axioms:
if e.free_symbols.issubset(expr.free_symbols):
subst.update(dict(self.get_implications(self.simplify(e))))
expr = expr.xreplace(subst)
# TODO: compute hint might have gotten broken here
fs = expr.free_symbols
if not fs and (expr.is_number or expr.is_Boolean):
return expr
def adjust_vr(k, vr):
# Check if the range can solve it statically quickly
if not (size_oblivious and k in self.size_like):
return vr
lower = max(2, vr.lower)
# Clamping size-oblivious to some quantity below sys.maxsize
# helps us determine that f(u0) != sys.maxsize, which is a
# test that is looking for sys.maxsize as a sentinel, but you
# don't really want to worry about it for unbacked SymInts.
# This is similar to the flavor where size oblivious omits
# 0/1, it changes semantics but in a benign way.
upper = min(2 ** 48, vr.upper)
# This is a bit dodgy: what this means is that there was a
# size-like unbacked symbol whose upper bound < 2. This
# causes... problems. When this happens, just ignore the
# preexisting upper bound
if lower > upper:
upper = max(lower, 2 ** 48)
return ValueRanges(lower, upper)
if var_to_range is None:
if size_oblivious: # micro-optimization
var_ranges = {k: adjust_vr(k, v) for k, v in self.var_to_range.items()}
else:
var_ranges = self.var_to_range
else:
var_ranges = dict(var_to_range)
var_ranges = {k: adjust_vr(k, v) for k, v in var_to_range}
out = bound_sympy(expr, var_ranges)
if out.is_singleton():
return out.lower
symbol_info = tuple(
(s, var_ranges.get(s), self.var_to_val.get(s), s in self.size_like)
(s, var_ranges.get(s), self.var_to_val.get(s))
for s in sorted(fs, key=lambda s: str(s)) # TODO: speed up sort?
)
r = _maybe_evaluate_static_worker(expr, symbol_info, unbacked_only, size_oblivious)
return r
return _maybe_evaluate_static_worker(expr, symbol_info, unbacked_only)
@_lru_cache
def replace(self, expr: "sympy.Expr") -> "sympy.Expr":
@@ -5421,6 +5436,7 @@ class ShapeEnv:
stack = CapturedTraceback.extract(skip=1)
guard = ShapeGuard(g, stack)
self.guards.append(guard)
self.axioms.update(dict(self.get_implications(self.simplify(g))))
else:
# it's fine to defer simple guards here without checking,
# the _maybe_guard_rel() call above will set replacements if possible,
@@ -5532,6 +5548,7 @@ class ShapeEnv:
# and the guard in question has no unbacked SymInts in front
ix = cands[-1] if cands else None
self.deferred_runtime_asserts.setdefault(ix, []).append(ra)
self.axioms.update(dict(self.get_implications(self.simplify(expr))))
self.num_deferred_runtime_asserts += 1
self._update_version_counter()
self._log_guard("runtime_assert", orig_expr, forcing_spec=False)