[user triton] dynamo support for new host-side TMA API (#155662)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155662
Approved by: https://github.com/aakhundov
ghstack dependencies: #155510
Author: David Berard
Date: 2025-06-11 17:42:16 -07:00
Committed by: PyTorch MergeBot
parent 9cced33c7c
commit 132babe7e0
4 changed files with 148 additions and 23 deletions
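For reference, two host-side TMA descriptor APIs are involved: the older triton.tools.experimental_descriptor helpers and the newer triton.tools.tensor_descriptor.TensorDescriptor class. A minimal sketch of creating a descriptor with each API, mirroring the tests added below (assumes a recent Triton and a TMA-capable CUDA GPU; the block sizes are illustrative):

import torch
import triton
from triton.tools.tensor_descriptor import TensorDescriptor

x = torch.randn(128, 128, device="cuda")

# Older "experimental" host-side API: raw data pointer plus explicit dims,
# block dims, and element size.
experimental_desc = triton.tools.experimental_descriptor.create_2d_tma_descriptor(
    x.data_ptr(), x.size(0), x.size(1), 32, 32, x.element_size()
)

# Newer host-side API: built directly from the tensor and a block shape.
stable_desc = TensorDescriptor.from_tensor(x, [32, 32])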


@@ -7,6 +7,11 @@ import unittest
 import torch
 import torch._dynamo.test_case
 from torch.testing._internal.common_utils import IS_FBCODE
+from torch.testing._internal.inductor_utils import requires_triton
+from torch.utils._triton import (
+    has_triton_experimental_host_tma,
+    has_triton_tensor_descriptor_host_tma,
+)
 
 
 def _filter_instructions(instructions, opname):
@@ -397,6 +402,52 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
         inp = torch.randn(3)
         self.assertEqual(gn(inp), inp + 3)
 
+    @requires_triton()
+    @unittest.skipIf(
+        not has_triton_experimental_host_tma(),
+        "Test requires triton.tools.experimental_descriptor API",
+    )
+    def test_tma_experimental_reconstruct(self):
+        import triton
+
+        def create_tma(tensor):
+            tma = triton.tools.experimental_descriptor.create_2d_tma_descriptor(
+                tensor.data_ptr(),
+                tensor.size(0),
+                tensor.size(1),
+                32,
+                32,
+                tensor.element_size(),
+            )
+            return tensor + 1, tma
+
+        x = torch.randn(128, 128, device="cuda")
+        ref = create_tma(x)
+        res = torch.compile(create_tma, backend="eager")(x)
+        self.assertEqual(ref[1].desc, res[1].desc)
+
+    @requires_triton()
+    @unittest.skipIf(
+        not has_triton_tensor_descriptor_host_tma(),
+        "Test requires triton.tools.tensor_descriptor API",
+    )
+    def test_tma_stable_reconstruct(self):
+        import triton
+
+        def create_tma(tensor):
+            tma = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
+                tensor,
+                [32, 32],
+            )
+            return tensor + 1, tma
+
+        x = torch.randn(128, 128, device="cuda")
+        ref = create_tma(x)
+        res = torch.compile(create_tma, backend="eager")(x)
+        self.assertEqual(ref, res)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
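The has_triton_experimental_host_tma() / has_triton_tensor_descriptor_host_tma() gates used in the skipIf decorators above come from torch.utils._triton. A plausible sketch of how such a capability check can be written, shown here only as an import probe and not necessarily the actual implementation:

def has_triton_tensor_descriptor_host_tma_sketch() -> bool:
    # Hypothetical helper: report whether the new host-side TMA API is importable.
    try:
        from triton.tools.tensor_descriptor import TensorDescriptor
    except ImportError:
        return False
    return hasattr(TensorDescriptor, "from_tensor")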


@@ -54,7 +54,8 @@ from .distributed import BackwardHookVariable, DistributedVariable, PlacementVar
 from .functions import (
     BuiltinMethodVariable,
     CollectionsNamedTupleFunction,
-    CreateTMADescriptorVariable,
+    CreateTMADescriptorExperimentalVariable,
+    CreateTMADescriptorStableVariable,
     FunctionDecoratedByContextlibContextManagerVariable,
     FunctoolsPartialVariable,
     FunctoolsWrapsVariable,
@@ -63,7 +64,8 @@ from .functions import (
     NestedUserFunctionVariable,
     PolyfilledFunctionVariable,
     SkipFunctionVariable,
-    TMADescriptorVariable,
+    TMADescriptorExperimentalVariable,
+    TMADescriptorStableVariable,
     UserFunctionVariable,
     UserMethodVariable,
 )
@@ -157,7 +159,8 @@ __all__ = [
     "ConstDictVariable",
     "ContextWrappingVariable",
     "CountIteratorVariable",
-    "CreateTMADescriptorVariable",
+    "CreateTMADescriptorExperimentalVariable",
+    "CreateTMADescriptorStableVariable",
     "CUDADeviceVariable",
     "CycleIteratorVariable",
     "DataPtrVariable",
@@ -198,7 +201,8 @@ __all__ = [
     "SuperVariable",
     "TemporarilyPopInterpreterStackCtxManagerVariable",
     "TensorVariable",
-    "TMADescriptorVariable",
+    "TMADescriptorExperimentalVariable",
+    "TMADescriptorStableVariable",
     "TorchCtxManagerClassVariable",
     "TorchInGraphFunctionVariable",
     "TorchVersionVariable",


@@ -189,7 +189,8 @@ from .functions import (
     BuiltinMethodVariable,
     CollectionsNamedTupleFunction,
     CollectiveFunctionRewriteVariable,
-    CreateTMADescriptorVariable,
+    CreateTMADescriptorExperimentalVariable,
+    CreateTMADescriptorStableVariable,
     FunctoolsPartialVariable,
     FunctoolsWrapsVariable,
     SysFunctionVariable,
@@ -606,7 +607,11 @@ class VariableBuilder:
     def _wrap(self, value):
         # import here to avoid circular dependencies
-        from torch.utils._triton import has_triton, has_triton_tma
+        from torch.utils._triton import (
+            has_triton,
+            has_triton_experimental_host_tma,
+            has_triton_tensor_descriptor_host_tma,
+        )
 
         from ..decorators import DynamoConfigPatchProxy
@@ -621,18 +626,25 @@
         class Autotuner:
             pass
 
-        if has_triton_tma():
-            from triton.tools.experimental_descriptor import (
+        # default implementations, in case we don't have triton (or the wrong triton version)
+        def create_1d_tma_descriptor():
+            pass
+
+        def create_2d_tma_descriptor():
+            pass
+
+        class TensorDescriptor:
+            @staticmethod
+            def from_tensor():
+                pass
+
+        if has_triton_experimental_host_tma():
+            from triton.tools.experimental_descriptor import (  # noqa: F811
                 create_1d_tma_descriptor,
                 create_2d_tma_descriptor,
             )
-        else:
-
-            def create_1d_tma_descriptor():
-                pass
-
-            def create_2d_tma_descriptor():
-                pass
+        if has_triton_tensor_descriptor_host_tma():
+            from triton.tools.tensor_descriptor import TensorDescriptor  # noqa: F811
 
         # Handle exact type() match
         type_dispatch = self._type_dispatch().get(type(value))
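The stub definitions above keep the later identity checks (value is create_1d_tma_descriptor, value is TensorDescriptor.from_tensor) valid even when Triton is missing or too old; when the real API is available, the import shadows the stub, which is why the redefinition warning is suppressed with # noqa: F811. A generic, runnable sketch of the same idiom (module and function names here are hypothetical, not from PyTorch):

import importlib.util

# Stub so later identity checks like `value is fast_helper` always have a target.
def fast_helper():  # hypothetical name
    pass

# Shadow the stub only when the optional dependency is importable;
# `# noqa: F811` silences the redefinition lint, as in the diff above.
if importlib.util.find_spec("some_optional_module") is not None:  # hypothetical module
    from some_optional_module import fast_helper  # noqa: F811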
@@ -1111,9 +1123,11 @@
                 source=self.source,
             )
         elif value is create_1d_tma_descriptor:
-            return CreateTMADescriptorVariable(rank=1)
+            return CreateTMADescriptorExperimentalVariable(rank=1)
         elif value is create_2d_tma_descriptor:
-            return CreateTMADescriptorVariable(rank=2)
+            return CreateTMADescriptorExperimentalVariable(rank=2)
+        elif value is TensorDescriptor.from_tensor:
+            return CreateTMADescriptorStableVariable()
         elif isinstance(value, torch.amp.autocast_mode.autocast):
             self.install_guards(GuardBuilder.ID_MATCH)
             return AutocastModeVariable(


@@ -2059,16 +2059,19 @@ class DynamoTritonHOPifier(TritonHOPifier):
         from .dicts import ConstDictVariable
 
         # as we can only pass tensors as non-const args in fx graph,
-        # here we replace TMA descriptors (TMADescriptorVariable
+        # here we replace TMA descriptors
+        # (TMADescriptorExperimentalVariable and TMADescriptorStableVariable
         # instances) with the underlying tensors, while moving the
         # TMA descriptor-related metadata to a separate argument,
         # so that we can reconstruct the TMA descriptors downstream
         tma_descriptor_metadata: TMADescriptorMetadata = {}
         for k in list(combined_args_raw.keys()):
             v = combined_args_raw[k]
-            if isinstance(v, TMADescriptorVariable):
+            if isinstance(
+                v, (TMADescriptorExperimentalVariable, TMADescriptorStableVariable)
+            ):
                 tma_descriptor_metadata[k] = v.to_metadata()
-                combined_args_raw[k] = v.data_ptr.from_tensor
+                combined_args_raw[k] = v.get_tensor()
 
         combined_args = {
             variables.ConstantVariable.create(k): v
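In short, the kernel's argument dict ends up holding only tensors (which can be passed through the fx graph), while the information needed to rebuild each descriptor travels separately in tma_descriptor_metadata. A purely illustrative picture of the split using plain dicts (not Dynamo's real types or metadata layout):

import torch

x = torch.randn(128, 128)

# What the kernel call site logically provided: a descriptor built from x.
call_site_args = {"in_desc": {"tensor": x, "block_shape": [32, 32]}}

# After the rewrite, the graph only ever sees the underlying tensor ...
combined_args = {"in_desc": call_site_args["in_desc"]["tensor"]}
# ... and the descriptor-rebuilding info is carried out of band.
tma_descriptor_metadata = {"in_desc": {"block_shape": [32, 32]}}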
@@ -2170,7 +2173,7 @@
         return arg
 
 
-class TMADescriptorVariable(VariableTracker):
+class TMADescriptorExperimentalVariable(VariableTracker):
     def __init__(
         self,
         data_ptr: "variables.DataPtrVariable",
@@ -2205,8 +2208,45 @@ class TMADescriptorVariable(VariableTracker):
         codegen.foreach(args)
         codegen.call_function(len(args) + 1, False)
 
+    def get_tensor(self):
+        return self.data_ptr.from_tensor
+
 
-class CreateTMADescriptorVariable(VariableTracker):
+class TMADescriptorStableVariable(VariableTracker):
+    def __init__(
+        self,
+        tensor: "variables.TensorVariable",
+        block_shape: "variables.ListVariable",
+        **kwargs,
+    ):
+        assert isinstance(tensor, variables.TensorVariable)
+        super().__init__(**kwargs)
+        self.tensor = tensor
+        self.block_shape = block_shape
+
+    def to_metadata(self):
+        # TODO(dberard) implement this
+        raise NotImplementedError(
+            "TensorDescriptor.from_tensor support is not yet implemented"
+        )
+
+    def reconstruct(self, codegen: "PyCodegen"):
+        codegen.add_push_null(
+            lambda: codegen.load_import_from(
+                "triton.tools.tensor_descriptor",
+                "TensorDescriptor",
+            )
+        )
+        codegen.load_method("from_tensor")
+        self.tensor.reconstruct(codegen)
+        codegen(self.block_shape)
+        codegen.call_method(2)
+
+    def get_tensor(self) -> "variables.TensorVariable":
+        return self.tensor
+
+
+class CreateTMADescriptorExperimentalVariable(VariableTracker):
     def __init__(
         self,
         rank: int,
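TMADescriptorStableVariable.reconstruct above emits bytecode that rebuilds the descriptor when a compiled function hands one back to Python. Roughly, the generated code is equivalent to the following sketch (the placeholder values stand in for the reconstructed tensor and recorded block shape; assumes a recent Triton and a CUDA tensor):

import torch
from triton.tools.tensor_descriptor import TensorDescriptor

# Placeholders for what the emitted bytecode pulls from the compiled
# function's state: the underlying tensor and the block shape.
tensor = torch.randn(128, 128, device="cuda")
block_shape = [32, 32]

rebuilt = TensorDescriptor.from_tensor(tensor, block_shape)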
@@ -2251,9 +2291,25 @@ class CreateTMADescriptorVariable(VariableTracker):
         ]
         element_size = kwargs["element_size"] if "element_size" in kwargs else args[-1]
 
-        return TMADescriptorVariable(
+        return TMADescriptorExperimentalVariable(
             data_ptr=ptr,
             dims=dims,
             block_dims=block_dims,
             element_size=element_size,
         )
+
+
+class CreateTMADescriptorStableVariable(VariableTracker):
+    def call_function(
+        self,
+        tx: "InstructionTranslator",
+        args: "list[VariableTracker]",
+        kwargs: "dict[str, VariableTracker]",
+    ) -> "VariableTracker":
+        tensor = kwargs["tensor"] if "tensor" in kwargs else args[0]
+        block_shape = kwargs["block_shape"] if "block_shape" in kwargs else args[1]
+        return TMADescriptorStableVariable(
+            tensor=tensor,
+            block_shape=block_shape,
+        )
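CreateTMADescriptorStableVariable.call_function accepts the tensor and block shape either positionally or by keyword, so both user-side call forms below should trace the same way (a small sketch; the keyword names mirror the keys the handler checks for, and a recent Triton with a CUDA GPU is assumed):

import torch
from triton.tools.tensor_descriptor import TensorDescriptor

x = torch.randn(128, 128, device="cuda")

desc_positional = TensorDescriptor.from_tensor(x, [32, 32])
desc_keyword = TensorDescriptor.from_tensor(tensor=x, block_shape=[32, 32])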