Add host-side Triton TMA support to Dynamo (#137677)

This adds Dynamo tracing support for the host-side Triton TMA API (see `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)). A few notes:

- Here we assume the availability of the host-side TMA API added to upstream Triton in https://github.com/triton-lang/triton/pull/4498. As of the time of writing, this is not yet part of the PT2 OSS Triton pin (although it has been back-ported internally). The OSS Triton pin update is expected in December 2024.
- To capture the chain of calls `t.data_ptr() --> create_{1d,2d}_tma_descriptor(ptr, ...) --> kernel[grid](tma_desc, ...)` (a sketch of this pattern is shown below the list), we add three new variable trackers: `DataPtrVariable`, `CreateTMADescriptorVariable` (for the function), and `TMADescriptorVariable` (for the TMA descriptor object). These maintain the path back from the Triton kernel to the Tensor from which the TMA descriptor was created.
- The newly introduced variable trackers have `reconstruct` methods, used in case of graph breaks.
- The `tma_descriptor_metadata` extracted from the captured `create_{1d,2d}_tma_descriptor` calls is propagated through the HOPs in Dynamo and AOTAutograd to be used by the downstream compiler (e.g., Inductor). See the unit tests for what the captured HOP arguments look like.
- In the Dynamo-captured fx graph, we replace the TMA descriptor arguments of the Triton kernel with the underlying Tensors, so that input/output relationships can be tracked in terms of Tensors.
- In the Triton kernel mutation analysis pass (in AOTAutograd), we use the `tt.experimental_descriptor_store` TTIR op to detect mutations of the underlying tensors performed via TMA descriptors, so that AOTAutograd can perform functionalization as required.
- JIT Inductor and AOT Inductor support will be implemented in follow-up PRs.
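
For illustration, here is a minimal sketch of the user code pattern that the new tracing path captures, adapted from the `test_tma_descriptor_1d` unit test added in this PR (it assumes a Hopper-class GPU and a Triton build that ships the experimental host-side TMA API; `add_kernel_with_tma_1d` is the test kernel added to `torch.testing._internal.triton_utils` in this PR):

```python
import torch
import triton
import triton.tools.experimental_descriptor

# Test kernel added to torch/testing/_internal/triton_utils.py in this PR
from torch.testing._internal.triton_utils import add_kernel_with_tma_1d


def f(a, b):
    BLOCK_SIZE = 256
    out = torch.zeros_like(a)
    n_elements = out.numel()
    # host-side chain traced by Dynamo:
    # t.data_ptr() --> create_1d_tma_descriptor(ptr, ...) --> kernel[grid](tma_desc, ...)
    desc_a, desc_b, desc_out = (
        triton.tools.experimental_descriptor.create_1d_tma_descriptor(
            t.data_ptr(), n_elements, BLOCK_SIZE, t.element_size()
        )
        for t in (a, b, out)
    )
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel_with_tma_1d[grid](desc_a, desc_b, desc_out, BLOCK_SIZE=BLOCK_SIZE)
    return out


a = torch.randn(301, device="cuda")
b = torch.randn(301, device="cuda")
compiled_out = torch.compile(f, fullgraph=True, backend="eager")(a, b)
```

Under `torch.compile`, the descriptor arguments are replaced by the underlying tensors in the captured graph, and the HOP additionally receives `tma_descriptor_metadata` of the form `{'in_desc_ptr0': ([301], [256], 4), ...}` (dims, block dims, element size per kernel parameter).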

Differential Revision: [D64404928](https://our.internmc.facebook.com/intern/diff/D64404928)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137677
Approved by: https://github.com/zou3519
Author: Adnan Akhundov
Date: 2024-10-15 09:09:00 -07:00
Committed by: PyTorch MergeBot
Parent: dd2ae7d0c9
Commit: 809ff3b274
12 changed files with 680 additions and 25 deletions

View File

@ -11534,6 +11534,41 @@ fn
self.assertEqual(expected, actual)
self.assertGreater(po.call_count, 0)
def test_data_ptr_graph_break_builtin(self):
def f(a, b):
# builtin + not implemented for DataPtrVariable
return a.data_ptr() + b.data_ptr()
a = torch.randn(4)
b = torch.randn(5)
# make sure there is a graph break
with self.assertRaises(torch._dynamo.exc.Unsupported):
torch.compile(f, backend="eager", fullgraph=True)(a, b)
torch._dynamo.reset()
expected = f(a, b)
actual = torch.compile(f, backend="eager")(a, b)
self.assertEqual(expected, actual)
def test_data_ptr_graph_break_aten(self):
def f(a):
# torch.add not implemented for DataPtrVariable
return torch.add(a, a.data_ptr())
a = torch.randn(4)
counters.clear()
expected = f(a)
actual = torch.compile(f, backend="eager")(a)
self.assertEqual(expected, actual)
self.assertTrue(len(counters["graph_break"]) > 0)
counters.clear()
class AssertNumOutputBackend:
"""
A backend that checks the number of output for compiled graph, and

View File

@ -29,7 +29,7 @@ from torch.testing._internal.logging_utils import logs_to_string
# Defines all the kernels for tests
from torch.testing._internal.triton_utils import * # noqa: F403
from torch.utils._triton import has_triton_package
from torch.utils._triton import has_triton_package, has_triton_tma
if HAS_GPU:
@ -99,6 +99,7 @@ class KernelTests(torch._inductor.test_case.TestCase):
kernel_idx=add_kernel_id,
constant_args_idx=constant_args_idx,
grid=[grid],
tma_descriptor_metadata={},
kwargs={
"in_ptr0": t1,
"in_ptr1": t2,
@ -115,6 +116,7 @@ class KernelTests(torch._inductor.test_case.TestCase):
kernel_idx=add_kernel_id,
constant_args_idx=constant_args_idx,
grid=[grid],
tma_descriptor_metadata={},
kwargs={
"in_ptr0": t1,
"in_ptr1": t2,
@ -145,6 +147,7 @@ class KernelTests(torch._inductor.test_case.TestCase):
{"n_elements": output.numel(), "BLOCK_SIZE": 16}
),
grid=[(x.numel(),)],
tma_descriptor_metadata={},
kwargs={
"in_ptr0": x,
"out_ptr": output,
@ -173,7 +176,7 @@ class KernelTests(torch._inductor.test_case.TestCase):
gm.code.strip(),
"""\
def forward(self, x_1, output_1):
triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 3, grid = [(5,)], kwargs = {'in_ptr0': x_1, 'out_ptr': output_1}, tensors_to_clone = ['in_ptr0', 'out_ptr']); x_1 = output_1 = None
triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 3, grid = [(5,)], tma_descriptor_metadata = {}, kwargs = {'in_ptr0': x_1, 'out_ptr': output_1}, tensors_to_clone = ['in_ptr0', 'out_ptr']); x_1 = output_1 = None
getitem = triton_kernel_wrapper_functional_proxy['in_ptr0']; getitem = None
getitem_1 = triton_kernel_wrapper_functional_proxy['out_ptr']; triton_kernel_wrapper_functional_proxy = None
return getitem_1""",
@ -217,6 +220,7 @@ def forward(self, x_1, output_1):
{"n_elements": x_func.numel(), "BLOCK_SIZE": 16}
),
grid=[(x_func.numel(),)],
tma_descriptor_metadata={},
kwargs={
"ptr": x_func,
},
@ -238,6 +242,7 @@ def forward(self, x_1, output_1):
{"n_elements": x_func.numel(), "BLOCK_SIZE": 16}
),
grid=[(x_func.numel(),)],
tma_descriptor_metadata={},
kwargs={
"ptr": x_func,
},
@ -1630,6 +1635,218 @@ def forward(self, x_1, output_1):
self.assertEqual(out2, x + y + 1)
self.assertEqual(out3, z**2)
@requires_gpu
@unittest.skipIf(not has_triton_tma(), "requires Triton TMA support")
@common_utils.parametrize("dynamic", [False, True])
def test_tma_capture_and_functionalize(self, dynamic):
def f(a, b):
BLOCK_SIZE = 256
out = torch.zeros_like(a)
n_elements = out.numel()
desc_a, desc_b, desc_out = (
triton.tools.experimental_descriptor.create_1d_tma_descriptor(
t.data_ptr(),
n_elements,
BLOCK_SIZE,
t.element_size(),
)
for t in (a, b, out)
)
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
add_kernel_with_tma_1d[grid](
desc_a,
desc_b,
desc_out,
BLOCK_SIZE=BLOCK_SIZE,
)
return out
a = torch.randn(301, device=GPU_TYPE)
b = torch.randn(301, device=GPU_TYPE)
backend = torch._dynamo.testing.AOTEagerAndRecordGraphs()
torch.compile(
f,
fullgraph=True,
backend=backend,
dynamic=dynamic,
)(a, b)
if dynamic:
self.assertExpectedInline(
backend.graphs[0].code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1):
zeros_like = torch.ops.aten.zeros_like.default(arg1_1, pin_memory = False)
add_2 = arg0_1 + 256
sub_1 = add_2 - 1; add_2 = None
floordiv = sub_1 // 256; sub_1 = None
triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 1, grid = [(floordiv, 1, 1)], tma_descriptor_metadata = {'in_desc_ptr0': ([arg0_1], [256], 4), 'in_desc_ptr1': ([arg0_1], [256], 4), 'out_desc_ptr': ([arg0_1], [256], 4)}, kwargs = {'in_desc_ptr0': arg1_1, 'in_desc_ptr1': arg2_1, 'out_desc_ptr': zeros_like}, tensors_to_clone = ['out_desc_ptr']); floordiv = arg0_1 = arg1_1 = arg2_1 = zeros_like = None
getitem = triton_kernel_wrapper_functional_proxy['out_desc_ptr']; triton_kernel_wrapper_functional_proxy = None
return (getitem,)""",
)
else:
self.assertExpectedInline(
backend.graphs[0].code.strip(),
"""\
def forward(self, arg0_1, arg1_1):
zeros_like = torch.ops.aten.zeros_like.default(arg0_1, pin_memory = False)
triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 0, grid = [(2, 1, 1)], tma_descriptor_metadata = {'in_desc_ptr0': ([301], [256], 4), 'in_desc_ptr1': ([301], [256], 4), 'out_desc_ptr': ([301], [256], 4)}, kwargs = {'in_desc_ptr0': arg0_1, 'in_desc_ptr1': arg1_1, 'out_desc_ptr': zeros_like}, tensors_to_clone = ['out_desc_ptr']); arg0_1 = arg1_1 = zeros_like = None
getitem = triton_kernel_wrapper_functional_proxy['out_desc_ptr']; triton_kernel_wrapper_functional_proxy = None
return (getitem,)""",
)
@requires_gpu
@unittest.skipIf(not has_triton_tma(), "requires Triton TMA support")
@common_utils.parametrize("after_data_ptr", [False, True])
@common_utils.parametrize("after_create_desc", [False, True])
def test_tma_graph_breaks(self, after_data_ptr, after_create_desc):
def f(a, b):
BLOCK_SIZE = 256
out = torch.zeros_like(a)
n_elements = out.numel()
ptrs = [t.data_ptr() for t in (a, b, out)]
if after_data_ptr:
torch._dynamo.graph_break()
descs = [
triton.tools.experimental_descriptor.create_1d_tma_descriptor(
ptr,
n_elements,
BLOCK_SIZE,
t.element_size(),
)
for ptr in ptrs
]
if after_create_desc:
torch._dynamo.graph_break()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
add_kernel_with_tma_1d[grid](
*descs,
BLOCK_SIZE=BLOCK_SIZE,
)
return out
a = torch.randn(301, device=GPU_TYPE)
b = torch.randn(301, device=GPU_TYPE)
expected_out = a + b
eager_out = f(a, b)
compiled_out = torch.compile(
f,
fullgraph=False,
backend="eager",
dynamic=False,
)(a, b)
self.assertEqual(eager_out, expected_out)
self.assertEqual(compiled_out, expected_out)
@requires_gpu
@unittest.skipIf(not has_triton_tma(), "requires Triton TMA support")
@common_utils.parametrize("dynamic", [False, True])
@common_utils.parametrize("backend", ["eager", "aot_eager"])
def test_tma_descriptor_1d(self, dynamic, backend):
def f(a, b):
BLOCK_SIZE = 256
out = torch.zeros_like(a)
n_elements = out.numel()
desc_a, desc_b, desc_out = (
triton.tools.experimental_descriptor.create_1d_tma_descriptor(
t.data_ptr(),
n_elements,
BLOCK_SIZE,
t.element_size(),
)
for t in (a, b, out)
)
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
add_kernel_with_tma_1d[grid](
desc_a,
desc_b,
desc_out,
BLOCK_SIZE=BLOCK_SIZE,
)
return out
a = torch.randn(301, device=GPU_TYPE)
b = torch.randn(301, device=GPU_TYPE)
expected_out = a + b
eager_out = f(a, b)
compiled_out = torch.compile(
f,
fullgraph=True,
backend=backend,
dynamic=dynamic,
)(a, b)
self.assertEqual(eager_out, expected_out)
self.assertEqual(compiled_out, expected_out)
@requires_gpu
@unittest.skipIf(not has_triton_tma(), "requires Triton TMA support")
@common_utils.parametrize("dynamic", [False, True])
@common_utils.parametrize("backend", ["eager", "aot_eager"])
def test_tma_descriptor_2d(self, dynamic, backend):
def f(a, b):
BLOCK_SIZE_X = 16
BLOCK_SIZE_Y = 32
out = torch.zeros_like(a)
x_size, y_size = out.size()
desc_a, desc_b, desc_out = (
triton.tools.experimental_descriptor.create_2d_tma_descriptor(
t.data_ptr(),
x_size,
y_size,
BLOCK_SIZE_X,
BLOCK_SIZE_Y,
t.element_size(),
)
for t in (a, b, out)
)
grid = lambda meta: (
triton.cdiv(x_size, meta["BLOCK_SIZE_X"]),
triton.cdiv(y_size, meta["BLOCK_SIZE_Y"]),
)
add_kernel_with_tma_2d[grid](
desc_a,
desc_b,
desc_out,
BLOCK_SIZE_X=BLOCK_SIZE_X,
BLOCK_SIZE_Y=BLOCK_SIZE_Y,
)
return out
a = torch.randn((25, 16), device=GPU_TYPE)
b = torch.randn((25, 16), device=GPU_TYPE)
expected_out = a + b
eager_out = f(a, b)
compiled_out = torch.compile(
f,
fullgraph=True,
backend=backend,
dynamic=dynamic,
)(a, b)
self.assertEqual(eager_out, expected_out)
self.assertEqual(compiled_out, expected_out)
@requires_gpu
@common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
def test_triton_kernel_num_ctas(self, backend):

View File

@ -120,11 +120,21 @@ def boxed_nop(fx_g, example_inputs):
# Useful for debugging purpose
# aot_eager uses AOT Autograd backend with nop compiler. It is helpful in debugging.
aot_eager = aot_autograd(
fw_compiler=boxed_nop,
partition_fn=min_cut_rematerialization_partition,
keep_inference_input_mutations=True,
)
def aot_eager(
gm,
fake_tensor_inputs,
fw_compiler=None,
bw_compiler=None,
**kwargs,
):
return aot_autograd(
fw_compiler=fw_compiler or boxed_nop,
bw_compiler=bw_compiler or boxed_nop,
partition_fn=min_cut_rematerialization_partition,
keep_inference_input_mutations=True,
)(gm, fake_tensor_inputs, **kwargs)
register_backend(name="aot_eager", compiler_fn=aot_eager)
aot_eager_default_partitioner = aot_autograd(

View File

@ -24,6 +24,7 @@ from unittest.mock import patch
import torch
from torch import fx
from torch._dynamo.backends.debugging import aot_eager
from torch._dynamo.output_graph import OutputGraph
from . import config, eval_frame, optimize_assert, reset
@ -245,6 +246,27 @@ class EagerAndRecordGraphs:
return gm.forward
# Equivalent to backend="aot_eager", but also records graphs that
# we can assert on
class AOTEagerAndRecordGraphs:
def __init__(self) -> None:
self.graphs: List[torch.fx.GraphModule] = []
def __call__(
self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
) -> Callable[..., Any]:
def save_graph(gm: torch.fx.GraphModule, *args: Any, **kwargs: Any) -> Any:
self.graphs.append(gm)
return gm.forward
return aot_eager(
gm,
example_inputs,
fw_compiler=save_graph,
bw_compiler=save_graph,
)
def strip_comment(code: str) -> str:
return re.sub(r"(?m)^ *#.*\n?", "", code)

View File

@ -30,10 +30,12 @@ from .dicts import (
)
from .distributed import BackwardHookVariable, DistributedVariable, PlacementVariable
from .functions import (
CreateTMADescriptorVariable,
FunctoolsPartialVariable,
NestedUserFunctionVariable,
PolyfilledFunctionVariable,
SkipFunctionVariable,
TMADescriptorVariable,
UserFunctionVariable,
UserMethodVariable,
)
@ -95,6 +97,7 @@ from .nn_module import (
from .optimizer import OptimizerVariable
from .sdpa import SDPAParamsVariable
from .tensor import (
DataPtrVariable,
FakeItemVariable,
NumpyNdarrayVariable,
SymNodeVariable,
@ -124,9 +127,11 @@ __all__ = [
"ConstDictVariable",
"ContextWrappingVariable",
"CountIteratorVariable",
"CreateTMADescriptorVariable",
"CUDADeviceVariable",
"CustomizedDictVariable",
"CycleIteratorVariable",
"DataPtrVariable",
"DefaultDictVariable",
"DeletedVariable",
"DeterministicAlgorithmsVariable",
@ -163,6 +168,7 @@ __all__ = [
"StringFormatVariable",
"SuperVariable",
"TensorVariable",
"TMADescriptorVariable",
"TorchCtxManagerClassVariable",
"TorchInGraphFunctionVariable",
"TorchVersionVariable",

View File

@ -141,6 +141,7 @@ from .distributed import (
)
from .functions import (
CollectiveFunctionRewriteVariable,
CreateTMADescriptorVariable,
FunctoolsPartialVariable,
TritonKernelVariable,
UserFunctionVariable,
@ -525,7 +526,7 @@ class VariableBuilder:
def _wrap(self, value):
# import here to avoid circular dependencies
from torch.utils._triton import has_triton
from torch.utils._triton import has_triton, has_triton_tma
if has_triton():
from triton.runtime.autotuner import Autotuner
@ -538,6 +539,19 @@ class VariableBuilder:
class Autotuner:
pass
if has_triton_tma():
from triton.tools.experimental_descriptor import (
create_1d_tma_descriptor,
create_2d_tma_descriptor,
)
else:
def create_1d_tma_descriptor():
pass
def create_2d_tma_descriptor():
pass
# Handle exact type() match
type_dispatch = self._type_dispatch().get(type(value))
if type_dispatch is not None:
@ -967,6 +981,10 @@ class VariableBuilder:
None, # No grid provided
source=self.source,
)
elif value is create_1d_tma_descriptor:
return CreateTMADescriptorVariable(rank=1)
elif value is create_2d_tma_descriptor:
return CreateTMADescriptorVariable(rank=2)
elif isinstance(value, torch.amp.autocast_mode.autocast):
self.install_guards(GuardBuilder.ID_MATCH)
return AutocastModeVariable(

View File

@ -1038,7 +1038,10 @@ class PolyfilledFunctionVariable(VariableTracker):
return self.fn
from torch._higher_order_ops.triton_kernel_wrap import TritonHOPifier
from torch._higher_order_ops.triton_kernel_wrap import (
TMADescriptorMetadata,
TritonHOPifier,
)
class DynamoTritonHOPifier(TritonHOPifier):
@ -1070,6 +1073,18 @@ class DynamoTritonHOPifier(TritonHOPifier):
from .constant import ConstantVariable
from .dicts import ConstDictVariable
# as we can only pass tensors as non-const args in the fx graph,
# here we replace TMA descriptors (TMADescriptorVariable
# instances) with the underlying tensors, while moving the
# TMA descriptor-related metadata to a separate argument,
# so that we can reconstruct the TMA descriptors downstream
tma_descriptor_metadata: TMADescriptorMetadata = {}
for k in list(combined_args_raw.keys()):
v = combined_args_raw[k]
if isinstance(v, TMADescriptorVariable):
tma_descriptor_metadata[k] = v.to_metadata()
combined_args_raw[k] = v.data_ptr.from_tensor
combined_args = {
variables.ConstantVariable.create(k): v
for k, v in combined_args_raw.items()
@ -1094,6 +1109,13 @@ class DynamoTritonHOPifier(TritonHOPifier):
if not isinstance(v, ConstantVariable)
}
for v in non_constant_args.values():
v = v.realize()
if not isinstance(v, (variables.TensorVariable, variables.SymNodeVariable)):
self.raise_unsupported(
f"Unexpected argument type for a Triton kernel: {repr(v)}."
)
constant_args_idx = kernel_side_table.add_constant_args(constant_args)
meta = ConstDictVariable(non_constant_args, dict)
tx.output.create_proxy(
@ -1104,6 +1126,7 @@ class DynamoTritonHOPifier(TritonHOPifier):
"kernel_idx": variable.kernel_idx,
"constant_args_idx": constant_args_idx,
"grid": grids,
"tma_descriptor_metadata": tma_descriptor_metadata,
"kwargs": meta.as_proxy(),
},
)
@ -1154,3 +1177,93 @@ class TritonKernelVariable(VariableTracker):
if isinstance(arg, SymNodeVariable):
return ConstantVariable.create(arg.evaluate_expr())
return arg
class TMADescriptorVariable(VariableTracker):
def __init__(
self,
data_ptr: "variables.DataPtrVariable",
dims: "List[ConstantVariable]",
block_dims: "List[ConstantVariable]",
element_size: "ConstantVariable",
**kwargs,
):
assert isinstance(data_ptr, variables.DataPtrVariable)
super().__init__(**kwargs)
self.data_ptr = data_ptr
self.dims = dims
self.block_dims = block_dims
self.element_size = element_size
def to_metadata(self):
return (
[dim.as_proxy() for dim in self.dims],
[dim.as_proxy() for dim in self.block_dims],
self.element_size.as_proxy(),
)
def reconstruct(self, codegen):
codegen.add_push_null(
lambda: codegen.load_import_from(
"triton.tools.experimental_descriptor",
f"create_{len(self.dims)}d_tma_descriptor",
)
)
self.data_ptr.reconstruct(codegen)
args = [*self.dims, *self.block_dims, self.element_size]
codegen.foreach(args)
codegen.call_function(len(args) + 1, False)
class CreateTMADescriptorVariable(VariableTracker):
def __init__(
self,
rank: int,
**kwargs,
) -> None:
super().__init__(**kwargs)
assert rank in (1, 2)
self.rank = rank
def call_function(
self,
tx: "InstructionTranslator",
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
ptr = kwargs["ptr"] if "ptr" in kwargs else args[0]
if not isinstance(ptr, variables.DataPtrVariable):
raise Unsupported(
"Please ensure there were no graph breaks between "
f"create_{self.rank}d_tma_descriptor and the upstream "
".data_ptr() call."
)
if self.rank == 1:
assert len(args) + len(kwargs) == 4
dims = [
kwargs["dim"] if "dim" in kwargs else args[1],
]
block_dims = [
kwargs["block_dim"] if "block_dim" in kwargs else args[2],
]
else:
assert len(args) + len(kwargs) == 6
dims = [
kwargs["dim1"] if "dim1" in kwargs else args[1],
kwargs["dim0"] if "dim0" in kwargs else args[2],
]
block_dims = [
kwargs["block_dim1"] if "block_dim1" in kwargs else args[3],
kwargs["block_dim0"] if "block_dim0" in kwargs else args[4],
]
element_size = kwargs["element_size"] if "element_size" in kwargs else args[-1]
return TMADescriptorVariable(
data_ptr=ptr,
dims=dims,
block_dims=block_dims,
element_size=element_size,
)

View File

@ -793,7 +793,7 @@ class TensorVariable(VariableTracker):
unimplemented("Tensor.backward")
def method_data_ptr(self, *args, **kwargs):
unimplemented("Tensor.data_ptr")
return DataPtrVariable(self)
def method_item(self, *args, **kwargs):
if not config.capture_scalar_outputs:
@ -1439,3 +1439,18 @@ class UntypedStorageVariable(VariableTracker):
codegen(self.from_tensor)
codegen.load_method("untyped_storage")
codegen.call_method(0)
class DataPtrVariable(VariableTracker):
def __init__(
self,
from_tensor: TensorVariable,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.from_tensor = from_tensor
def reconstruct(self, codegen):
codegen(self.from_tensor, allow_cache=False)
codegen.load_method("data_ptr")
codegen.call_method(0)

View File

@ -6,11 +6,11 @@ import inspect
import logging
import threading
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import torch.fx as fx
import torch.utils._pytree as pytree
from torch import Tensor
from torch import SymInt, Tensor
from torch._C import DispatchKey
from torch._ops import HigherOrderOperator
from torch._prims_common import clone_preserve_strides
@ -25,6 +25,21 @@ from torch.fx.experimental.symbolic_shapes import guard_scalar
log = logging.getLogger("torch._dynamo")
# TMADescriptorMetadata maps kernel parameter names to the metadata that allows
# reconstructing TMA descriptors from the underlying tensors (passed as kernel
# arguments in the fx graph, instead of the TMA descriptors). Namely: a tuple
# consisting of a list of dims, a list of block dims, and the element size. E.g., for this
# call in the host-side Triton TMA API ``create_2d_tma_descriptor(ptr, 50, 60, 32, 15, 4)``,
# the metadata will look like ``([50, 60], [32, 15], 4)``. All ints can be SymInts.
TMADescriptorMetadata = Dict[
str, # kernel parameter name
Tuple[
List[Union[int, SymInt]], # dims
List[Union[int, SymInt]], # block_dims
Union[int, SymInt], # element_size
],
]
###############################################################################
# Kernel Side Table
@ -422,7 +437,12 @@ def analyze_kernel_mutations(functions, fn_name, num_args):
# List from Triton Github include/triton/Dialect/Triton/IR/TritonOps.td
# All the OPs that have MemWrite trait.
# What if Triton exposed this?
MUTATION_OPS = {"tt.store": [0], "tt.atomic_cas": [0], "tt.atomic_rmw": [0]}
MUTATION_OPS = {
"tt.store": [0],
"tt.atomic_cas": [0],
"tt.atomic_rmw": [0],
"tt.experimental_descriptor_store": [0],
}
# Ops that we want to bail out on
UNKNOWN_OPS = {"tt.elementwise_inline_asm"}
@ -524,11 +544,19 @@ class TritonKernelWrapperMutation(HigherOrderOperator):
def __init__(self) -> None:
super().__init__("triton_kernel_wrapper_mutation", cacheable=False)
def __call__(self, kernel_idx, constant_args_idx, grid, kwargs):
def __call__(
self,
kernel_idx,
constant_args_idx,
grid,
tma_descriptor_metadata: TMADescriptorMetadata,
kwargs,
):
return super().__call__(
kernel_idx=kernel_idx,
constant_args_idx=constant_args_idx,
grid=grid,
tma_descriptor_metadata=tma_descriptor_metadata,
kwargs=kwargs,
)
@ -541,11 +569,20 @@ class TritonKernelWrapperFunctional(HigherOrderOperator):
def __init__(self) -> None:
super().__init__("triton_kernel_wrapper_functional", cacheable=False)
def __call__(self, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone):
def __call__(
self,
kernel_idx,
constant_args_idx,
grid,
tma_descriptor_metadata: TMADescriptorMetadata,
kwargs,
tensors_to_clone,
):
return super().__call__(
kernel_idx=kernel_idx,
constant_args_idx=constant_args_idx,
grid=grid,
tma_descriptor_metadata=tma_descriptor_metadata,
kwargs=kwargs,
tensors_to_clone=tensors_to_clone,
)
@ -556,7 +593,12 @@ triton_kernel_wrapper_functional = TritonKernelWrapperFunctional()
@triton_kernel_wrapper_mutation.py_impl(DispatchKey.CompositeExplicitAutograd)
def triton_kernel_wrapper_mutation_dense(
*, kernel_idx, constant_args_idx, grid, kwargs
*,
kernel_idx,
constant_args_idx,
grid,
tma_descriptor_metadata: TMADescriptorMetadata,
kwargs,
):
from torch._inductor.codegen.wrapper import user_defined_kernel_grid_fn_code
@ -573,19 +615,56 @@ def triton_kernel_wrapper_mutation_dense(
exec(code, namespace)
grid_fn = namespace[fn_name]
if tma_descriptor_metadata:
from triton.tools.experimental_descriptor import ( # noqa: F401
create_1d_tma_descriptor,
create_2d_tma_descriptor,
)
# as we need to launch the kernel here, we "unwrap" the
# tma_descriptor_metadata, create the TMA descriptors
# from it, and replace the tensors in the kwargs with the
# corresponding TMA descriptors before launching
kwargs = kwargs.copy()
for k, v in tma_descriptor_metadata.items():
tensor = kwargs[k]
dims, block_dims, element_size = v
create_tma_descriptor = (
create_1d_tma_descriptor if len(dims) == 1 else create_2d_tma_descriptor
)
kwargs[k] = create_tma_descriptor(
tensor.data_ptr(),
*dims,
*block_dims,
element_size,
)
kernel[grid_fn](**kwargs, **constant_args)
@triton_kernel_wrapper_mutation.py_impl(FakeTensorMode)
def triton_kernel_wrapper_mutation_fake_tensor_mode(
mode, *, kernel_idx, constant_args_idx, grid, kwargs
mode,
*,
kernel_idx,
constant_args_idx,
grid,
tma_descriptor_metadata: TMADescriptorMetadata,
kwargs,
):
with mode:
return None
@triton_kernel_wrapper_mutation.py_impl(DispatchKey.Meta)
def _(*, kernel_idx, constant_args_idx, grid, kwargs):
def _(
*,
kernel_idx,
constant_args_idx,
grid,
tma_descriptor_metadata: TMADescriptorMetadata,
kwargs,
):
return None
@ -607,7 +686,13 @@ def trace_triton_kernel_wrapper(proxy_mode, func_overload, node_args):
@triton_kernel_wrapper_mutation.py_impl(ProxyTorchDispatchMode)
def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode(
mode, *, kernel_idx, constant_args_idx, grid, kwargs
mode,
*,
kernel_idx,
constant_args_idx,
grid,
tma_descriptor_metadata: TMADescriptorMetadata,
kwargs,
):
trace_triton_kernel_wrapper(
mode,
@ -616,6 +701,7 @@ def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode(
"kernel_idx": kernel_idx,
"constant_args_idx": constant_args_idx,
"grid": grid,
"tma_descriptor_metadata": tma_descriptor_metadata,
"kwargs": kwargs,
},
)
@ -631,7 +717,12 @@ def get_mutated_tensors(kernel_idx, constant_args_idx, kwargs):
@triton_kernel_wrapper_mutation.py_functionalize_impl
def triton_kernel_wrapper_mutation_functionalize(
ctx, kernel_idx, constant_args_idx, grid, kwargs
ctx,
kernel_idx,
constant_args_idx,
grid,
tma_descriptor_metadata: TMADescriptorMetadata,
kwargs,
):
unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
# TODO(oulgen): Preexisting bug, if two kernel inputs are views of each
@ -646,6 +737,7 @@ def triton_kernel_wrapper_mutation_functionalize(
kernel_idx=kernel_idx,
constant_args_idx=constant_args_idx,
grid=grid,
tma_descriptor_metadata=tma_descriptor_metadata,
kwargs=unwrapped_kwargs,
tensors_to_clone=tensors_to_clone,
)
@ -667,7 +759,13 @@ def triton_kernel_wrapper_mutation_functionalize(
@triton_kernel_wrapper_functional.py_impl(DispatchKey.CompositeExplicitAutograd)
def triton_kernel_wrapper_functional_dense(
*, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone
*,
kernel_idx,
constant_args_idx,
grid,
tma_descriptor_metadata: TMADescriptorMetadata,
kwargs,
tensors_to_clone,
):
# TODO(oulgen): For performance reasons, we want to ensure that these
# `clone_preserve_strides` calls are never executed at runtime
@ -681,6 +779,7 @@ def triton_kernel_wrapper_functional_dense(
kernel_idx=kernel_idx,
constant_args_idx=constant_args_idx,
grid=grid,
tma_descriptor_metadata=tma_descriptor_metadata,
kwargs=kwargs,
)
return {key: val for key, val in kwargs.items() if key in tensors_to_clone}
@ -688,7 +787,14 @@ def triton_kernel_wrapper_functional_dense(
@triton_kernel_wrapper_functional.py_impl(FakeTensorMode)
def triton_kernel_wrapper_functional_fake_tensor_mode(
mode, *, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone
mode,
*,
kernel_idx,
constant_args_idx,
grid,
tma_descriptor_metadata: TMADescriptorMetadata,
kwargs,
tensors_to_clone,
):
# TODO(oulgen): For performance reasons, we want to ensure that these
# `clone_preserve_strides` calls are never executed at runtime
@ -704,7 +810,14 @@ def triton_kernel_wrapper_functional_fake_tensor_mode(
@triton_kernel_wrapper_functional.py_impl(ProxyTorchDispatchMode)
def triton_kernel_wrapper_functional_proxy_torch_dispatch_mode(
mode, *, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone
mode,
*,
kernel_idx,
constant_args_idx,
grid,
tma_descriptor_metadata: TMADescriptorMetadata,
kwargs,
tensors_to_clone,
):
return trace_triton_kernel_wrapper(
mode,
@ -713,6 +826,7 @@ def triton_kernel_wrapper_functional_proxy_torch_dispatch_mode(
"kernel_idx": kernel_idx,
"constant_args_idx": constant_args_idx,
"grid": grid,
"tma_descriptor_metadata": tma_descriptor_metadata,
"kwargs": kwargs,
"tensors_to_clone": tensors_to_clone,
},
@ -721,7 +835,13 @@ def triton_kernel_wrapper_functional_proxy_torch_dispatch_mode(
@triton_kernel_wrapper_functional.py_functionalize_impl
def triton_kernel_wrapper_functional_functionalize(
ctx, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone
ctx,
kernel_idx,
constant_args_idx,
grid,
tma_descriptor_metadata: TMADescriptorMetadata,
kwargs,
tensors_to_clone,
):
unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
with ctx.redispatch_to_next():
@ -729,6 +849,7 @@ def triton_kernel_wrapper_functional_functionalize(
kernel_idx=kernel_idx,
constant_args_idx=constant_args_idx,
grid=grid,
tma_descriptor_metadata=tma_descriptor_metadata,
kwargs=unwrapped_kwargs,
tensors_to_clone=tensors_to_clone,
)
@ -1057,6 +1178,9 @@ class TracingTritonHOPifier(TritonHOPifier):
kernel_idx=variable.kernel_idx,
constant_args_idx=constant_args_idx,
grid=grids,
# TMA descriptor capturing not yet
# supported in non-dynamo tracing
tma_descriptor_metadata={},
kwargs=graphable_args,
)

View File

@ -6322,7 +6322,14 @@ make_fallback(auto_functionalized)
@register_lowering(triton_kernel_wrapper_mutation)
def triton_kernel_wrap_(*, kernel_idx, constant_args_idx, grid, kwargs):
def triton_kernel_wrap_(
*,
kernel_idx,
constant_args_idx,
grid,
tma_descriptor_metadata,
kwargs,
):
from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table
constant_args = kernel_side_table.get_constant_args(constant_args_idx)

View File

@ -191,6 +191,71 @@ if has_triton():
output = (x + y) * scaling_factor
tl.store(out_ptr + offsets, output, mask=mask)
@triton.jit
def add_kernel_with_tma_1d(
in_desc_ptr0,
in_desc_ptr1,
out_desc_ptr,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
offset = pid * BLOCK_SIZE
a = tl._experimental_descriptor_load(
in_desc_ptr0,
[offset],
[BLOCK_SIZE],
tl.float32,
)
b = tl._experimental_descriptor_load(
in_desc_ptr1,
[offset],
[BLOCK_SIZE],
tl.float32,
)
output = a + b
tl._experimental_descriptor_store(
out_desc_ptr,
output,
[offset],
)
@triton.jit
def add_kernel_with_tma_2d(
in_desc_ptr0,
in_desc_ptr1,
out_desc_ptr,
BLOCK_SIZE_X: "tl.constexpr",
BLOCK_SIZE_Y: "tl.constexpr",
):
pid_x = tl.program_id(axis=0)
pid_y = tl.program_id(axis=1)
offset_x = pid_x * BLOCK_SIZE_X
offset_y = pid_y * BLOCK_SIZE_Y
x = tl._experimental_descriptor_load(
in_desc_ptr0,
[offset_x, offset_y],
[BLOCK_SIZE_X, BLOCK_SIZE_Y],
tl.float32,
)
y = tl._experimental_descriptor_load(
in_desc_ptr1,
[offset_x, offset_y],
[BLOCK_SIZE_X, BLOCK_SIZE_Y],
tl.float32,
)
output = x + y
tl._experimental_descriptor_store(
out_desc_ptr,
output,
[offset_x, offset_y],
)
@triton.jit
def mul2_kernel(
in_ptr0,

View File

@ -15,6 +15,29 @@ def has_triton_package() -> bool:
return False
@functools.lru_cache(None)
def has_triton_tma():
if has_triton_package():
import torch
if (
torch.cuda.is_available()
and torch.cuda.get_device_capability() >= (9, 0)
and not torch.version.hip
):
try:
from triton.tools.experimental_descriptor import ( # noqa: F401
create_1d_tma_descriptor,
create_2d_tma_descriptor,
)
return True
except ImportError:
pass
return False
@functools.lru_cache(None)
def has_triton() -> bool:
if not has_triton_package():