Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[user triton] dynamo support for new host-side TMA API (#155662)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155662
Approved by: https://github.com/aakhundov
ghstack dependencies: #155510
committed by PyTorch MergeBot
parent 9cced33c7c
commit 132babe7e0
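For context, these are the two host-side TMA descriptor APIs involved, mirroring the new tests added below. This is a minimal sketch, assuming a CUDA device and a Triton build that exposes both the experimental descriptor module and the new tensor_descriptor module:

    import torch
    import triton

    x = torch.randn(128, 128, device="cuda")

    # older, experimental host-side API
    tma_experimental = triton.tools.experimental_descriptor.create_2d_tma_descriptor(
        x.data_ptr(), x.size(0), x.size(1), 32, 32, x.element_size()
    )

    # new host-side API that this commit teaches dynamo to trace and reconstruct
    tma_stable = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(x, [32, 32])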
@@ -7,6 +7,11 @@ import unittest
 import torch
 import torch._dynamo.test_case
 from torch.testing._internal.common_utils import IS_FBCODE
+from torch.testing._internal.inductor_utils import requires_triton
+from torch.utils._triton import (
+    has_triton_experimental_host_tma,
+    has_triton_tensor_descriptor_host_tma,
+)


 def _filter_instructions(instructions, opname):
@@ -397,6 +402,52 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
         inp = torch.randn(3)
         self.assertEqual(gn(inp), inp + 3)

+    @requires_triton()
+    @unittest.skipIf(
+        not has_triton_experimental_host_tma(),
+        "Test requires triton.tools.experimental_descriptor API",
+    )
+    def test_tma_experimental_reconstruct(self):
+        import triton
+
+        def create_tma(tensor):
+            tma = triton.tools.experimental_descriptor.create_2d_tma_descriptor(
+                tensor.data_ptr(),
+                tensor.size(0),
+                tensor.size(1),
+                32,
+                32,
+                tensor.element_size(),
+            )
+            return tensor + 1, tma
+
+        x = torch.randn(128, 128, device="cuda")
+
+        ref = create_tma(x)
+        res = torch.compile(create_tma, backend="eager")(x)
+        self.assertEqual(ref[1].desc, res[1].desc)
+
+    @requires_triton()
+    @unittest.skipIf(
+        not has_triton_tensor_descriptor_host_tma(),
+        "Test requires triton.tools.tensor_descriptor API",
+    )
+    def test_tma_stable_reconstruct(self):
+        import triton
+
+        def create_tma(tensor):
+            tma = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
+                tensor,
+                [32, 32],
+            )
+            return tensor + 1, tma
+
+        x = torch.randn(128, 128, device="cuda")
+
+        ref = create_tma(x)
+        res = torch.compile(create_tma, backend="eager")(x)
+        self.assertEqual(ref, res)
+


 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
@@ -54,7 +54,8 @@ from .distributed import BackwardHookVariable, DistributedVariable, PlacementVar
 from .functions import (
     BuiltinMethodVariable,
     CollectionsNamedTupleFunction,
-    CreateTMADescriptorVariable,
+    CreateTMADescriptorExperimentalVariable,
+    CreateTMADescriptorStableVariable,
     FunctionDecoratedByContextlibContextManagerVariable,
     FunctoolsPartialVariable,
     FunctoolsWrapsVariable,
@@ -63,7 +64,8 @@ from .functions import (
     NestedUserFunctionVariable,
     PolyfilledFunctionVariable,
     SkipFunctionVariable,
-    TMADescriptorVariable,
+    TMADescriptorExperimentalVariable,
+    TMADescriptorStableVariable,
     UserFunctionVariable,
     UserMethodVariable,
 )
@@ -157,7 +159,8 @@ __all__ = [
     "ConstDictVariable",
     "ContextWrappingVariable",
     "CountIteratorVariable",
-    "CreateTMADescriptorVariable",
+    "CreateTMADescriptorExperimentalVariable",
+    "CreateTMADescriptorStableVariable",
     "CUDADeviceVariable",
     "CycleIteratorVariable",
     "DataPtrVariable",
@@ -198,7 +201,8 @@ __all__ = [
     "SuperVariable",
     "TemporarilyPopInterpreterStackCtxManagerVariable",
     "TensorVariable",
-    "TMADescriptorVariable",
+    "TMADescriptorExperimentalVariable",
+    "TMADescriptorStableVariable",
     "TorchCtxManagerClassVariable",
     "TorchInGraphFunctionVariable",
     "TorchVersionVariable",
@@ -189,7 +189,8 @@ from .functions import (
     BuiltinMethodVariable,
     CollectionsNamedTupleFunction,
     CollectiveFunctionRewriteVariable,
-    CreateTMADescriptorVariable,
+    CreateTMADescriptorExperimentalVariable,
+    CreateTMADescriptorStableVariable,
     FunctoolsPartialVariable,
     FunctoolsWrapsVariable,
     SysFunctionVariable,
@@ -606,7 +607,11 @@ class VariableBuilder:

     def _wrap(self, value):
         # import here to avoid circular dependencies
-        from torch.utils._triton import has_triton, has_triton_tma
+        from torch.utils._triton import (
+            has_triton,
+            has_triton_experimental_host_tma,
+            has_triton_tensor_descriptor_host_tma,
+        )

         from ..decorators import DynamoConfigPatchProxy

@@ -621,18 +626,25 @@ class VariableBuilder:
         class Autotuner:
             pass

-        if has_triton_tma():
-            from triton.tools.experimental_descriptor import (
-                create_1d_tma_descriptor,
-                create_2d_tma_descriptor,
-            )
-        else:
-
-            def create_1d_tma_descriptor():
-                pass
-
-            def create_2d_tma_descriptor():
-                pass
+        # default implementations, in case we don't have triton (or the wrong triton version)
+        def create_1d_tma_descriptor():
+            pass
+
+        def create_2d_tma_descriptor():
+            pass
+
+        class TensorDescriptor:
+            @staticmethod
+            def from_tensor():
+                pass
+
+        if has_triton_experimental_host_tma():
+            from triton.tools.experimental_descriptor import (  # noqa: F811
+                create_1d_tma_descriptor,
+                create_2d_tma_descriptor,
+            )
+        if has_triton_tensor_descriptor_host_tma():
+            from triton.tools.tensor_descriptor import TensorDescriptor  # noqa: F811

         # Handle exact type() match
         type_dispatch = self._type_dispatch().get(type(value))
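The hunk above uses a stub-then-override pattern: define no-op placeholders unconditionally so the identity checks further down in _wrap always have something to compare against, then shadow them with the real Triton symbols when the matching API is available (the # noqa: F811 suppresses the redefinition lint). A minimal sketch of the same idea, with a hypothetical optional_mod standing in for the optional dependency:

    # no-op default so `value is fast_create` comparisons stay well-defined
    def fast_create():
        pass

    try:
        # shadow the stub when the optional dependency provides the real symbol
        from optional_mod import fast_create  # noqa: F811
    except ImportError:
        pass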
@@ -1111,9 +1123,11 @@ class VariableBuilder:
                 source=self.source,
             )
         elif value is create_1d_tma_descriptor:
-            return CreateTMADescriptorVariable(rank=1)
+            return CreateTMADescriptorExperimentalVariable(rank=1)
         elif value is create_2d_tma_descriptor:
-            return CreateTMADescriptorVariable(rank=2)
+            return CreateTMADescriptorExperimentalVariable(rank=2)
+        elif value is TensorDescriptor.from_tensor:
+            return CreateTMADescriptorStableVariable()
         elif isinstance(value, torch.amp.autocast_mode.autocast):
             self.install_guards(GuardBuilder.ID_MATCH)
             return AutocastModeVariable(
@@ -2059,16 +2059,19 @@ class DynamoTritonHOPifier(TritonHOPifier):
         from .dicts import ConstDictVariable

         # as we can only pass tensors as non-const args in fx graph,
-        # here we replace TMA descriptors (TMADescriptorVariable
+        # here we replace TMA descriptors
+        # (TMADescriptorExperimentalVariable and TMADescriptorStableVariable
         # instances) with the underlying tensors, while moving the
         # TMA descriptor-related metadata to a separate argument,
         # so that we can reconstruct the TMA descriptors downstream
         tma_descriptor_metadata: TMADescriptorMetadata = {}
         for k in list(combined_args_raw.keys()):
             v = combined_args_raw[k]
-            if isinstance(v, TMADescriptorVariable):
+            if isinstance(
+                v, (TMADescriptorExperimentalVariable, TMADescriptorStableVariable)
+            ):
                 tma_descriptor_metadata[k] = v.to_metadata()
-                combined_args_raw[k] = v.data_ptr.from_tensor
+                combined_args_raw[k] = v.get_tensor()

         combined_args = {
             variables.ConstantVariable.create(k): v
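To make the comment in the hunk above concrete, here is a rough illustration of what dynamo does with a TMA descriptor kernel argument. The dictionary layout shown is an assumption for illustration only, not the exact TMADescriptorMetadata type:

    import torch

    x = torch.randn(128, 128, device="cuda")

    # user code would create a host-side descriptor over x and pass it to the kernel;
    # since only tensors can be non-const fx graph arguments, dynamo substitutes the
    # underlying tensor and stashes the descriptor parameters on the side
    combined_args_raw = {"in_desc": x}  # descriptor replaced by its tensor
    tma_descriptor_metadata = {
        # assumed layout: (dims, block_dims, element_size) per descriptor argument
        "in_desc": ([128, 128], [32, 32], x.element_size()),
    }
    # downstream code can rebuild the host-side descriptor from the tensor plus this metadata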
@@ -2170,7 +2173,7 @@ class TritonKernelVariable(VariableTracker):
             return arg


-class TMADescriptorVariable(VariableTracker):
+class TMADescriptorExperimentalVariable(VariableTracker):
     def __init__(
         self,
         data_ptr: "variables.DataPtrVariable",
|
||||
codegen.foreach(args)
|
||||
codegen.call_function(len(args) + 1, False)
|
||||
|
||||
def get_tensor(self):
|
||||
return self.data_ptr.from_tensor
|
||||
|
||||
class CreateTMADescriptorVariable(VariableTracker):
|
||||
|
||||
class TMADescriptorStableVariable(VariableTracker):
|
||||
def __init__(
|
||||
self,
|
||||
tensor: "variables.TensorVariable",
|
||||
block_shape: "variables.ListVariable",
|
||||
**kwargs,
|
||||
):
|
||||
assert isinstance(tensor, variables.TensorVariable)
|
||||
super().__init__(**kwargs)
|
||||
self.tensor = tensor
|
||||
self.block_shape = block_shape
|
||||
|
||||
def to_metadata(self):
|
||||
# TODO(dberard) implement this
|
||||
raise NotImplementedError(
|
||||
"TensorDescriptor.from_tensor support is not yet implemented"
|
||||
)
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from(
|
||||
"triton.tools.tensor_descriptor",
|
||||
"TensorDescriptor",
|
||||
)
|
||||
)
|
||||
codegen.load_method("from_tensor")
|
||||
self.tensor.reconstruct(codegen)
|
||||
codegen(self.block_shape)
|
||||
codegen.call_method(2)
|
||||
|
||||
def get_tensor(self) -> "variables.TensorVariable":
|
||||
return self.tensor
|
||||
|
||||
|
||||
class CreateTMADescriptorExperimentalVariable(VariableTracker):
|
||||
def __init__(
|
||||
self,
|
||||
rank: int,
|
||||
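For reference, the reconstruct method of TMADescriptorStableVariable above emits bytecode roughly equivalent to re-issuing the stable API call when the descriptor has to be rebuilt at a graph boundary. A sketch, where tensor and block_shape stand for the reconstructed tensor and the recorded block shape:

    from triton.tools.tensor_descriptor import TensorDescriptor

    # load TensorDescriptor, resolve from_tensor, push the tensor and the block
    # shape, then invoke the method with two arguments -- i.e. call_method(2)
    desc = TensorDescriptor.from_tensor(tensor, block_shape)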
@@ -2251,9 +2291,25 @@ class CreateTMADescriptorVariable(VariableTracker):
         ]
         element_size = kwargs["element_size"] if "element_size" in kwargs else args[-1]

-        return TMADescriptorVariable(
+        return TMADescriptorExperimentalVariable(
             data_ptr=ptr,
             dims=dims,
             block_dims=block_dims,
             element_size=element_size,
         )
+
+
+class CreateTMADescriptorStableVariable(VariableTracker):
+    def call_function(
+        self,
+        tx: "InstructionTranslator",
+        args: "list[VariableTracker]",
+        kwargs: "dict[str, VariableTracker]",
+    ) -> "VariableTracker":
+        tensor = kwargs["tensor"] if "tensor" in kwargs else args[0]
+        block_shape = kwargs["block_shape"] if "block_shape" in kwargs else args[1]
+
+        return TMADescriptorStableVariable(
+            tensor=tensor,
+            block_shape=block_shape,
+        )