Compare commits

...

2 Commits

3 changed files with 22 additions and 6 deletions

View File

@@ -240,6 +240,7 @@ class CommonTemplate:
test_torchinductor.skip_if_triton_cpu("Triton CPU: slow test")
],
),
((64,), (64,), None, None, True),
],
)
def test_pointwise(

View File

@@ -23,6 +23,7 @@ import torch._logging
import torch.utils._pytree as pytree
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.utils import identity, preserve_rng_state
from torch._inductor import shape_propagation
from torch._prims_common import is_integer_dtype
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
@@ -2555,10 +2556,13 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
#
# To prevent unintended side effects we will gate options 1-3 behind isinstance(indexing, TensorDescriptorOptions).
if isinstance(indexing, TensorDescriptorOptions) and value.shape:
str_final_shape = tuple(symt.name for symt in indexing.final_shape)
if value.shape[::-1] == str_final_shape:
value = f"tl.trans({value})"
elif value.shape != str_final_shape:
str_final_shape = shape_propagation.cast_sym_block_shape(
indexing.final_shape
)
value_shape = tuple(value.shape)
if len(value_shape) > 1 and value_shape[::-1] == str_final_shape:
value = f"tl.trans({value}, {list(reversed(range(len(value.shape))))})"
elif value_shape != str_final_shape:
raise AssertionError(
"TMA store requires no broadcasting when a shape is provided"
)
@@ -2756,7 +2760,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
line = indexing.codegen_broadcast_and_reshape(
line, indexing.block_shape, indexing.final_shape, True
)
shape = indexing.final_shape
shape = shape_propagation.cast_sym_block_shape(indexing.final_shape)
elif isinstance(original_index, sympy.Integer):
line = f"tl.load({var} + ({original_index}))"
append_broadcast = indexing.expand_str
@@ -2830,6 +2834,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
for_store=True,
force=force,
)
indexing = self.indexing(
index,
dense_indexing=True,

View File

@@ -1,6 +1,6 @@
import functools
from collections.abc import Sequence
from typing import Callable, Optional, Protocol, Union
from typing import Callable, cast, Optional, Protocol, Union
import sympy
@@ -10,6 +10,16 @@ from .virtualized import OpsValue, V
BlockShapeType = Optional[Sequence[Union[int, str]]]
SymBlockShapeType = Optional[Sequence[Union[int, str, sympy.Expr]]]
def cast_sym_block_shape(sym_block_shape: SymBlockShapeType) -> BlockShapeType:
    """Lower a symbolic block shape to a plain ``int``/``str`` shape.

    Any ``sympy.Expr`` entry is rendered to its string form via the current
    kernel's ``index_to_str``; ``int`` and ``str`` entries pass through
    unchanged. A ``None`` shape stays ``None``.
    """
    if sym_block_shape is None:
        return None
    lowered = tuple(
        V.kernel.index_to_str(dim) if isinstance(dim, sympy.Expr) else dim
        for dim in sym_block_shape
    )
    return cast(BlockShapeType, lowered)
class ShapeVar(Protocol):