Add host-side Triton TMA support to Dynamo (#137677)
This adds Dynamo tracing support for the host-side Triton TMA API (see the `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)). A few notes:

- Here we assume the availability of the host-side TMA API added to upstream Triton in https://github.com/triton-lang/triton/pull/4498. As of the time of writing, this is not part of the PT2 OSS Triton pin (although it has been back-ported internally). The OSS Triton pin update should be done in December 2024.
- To capture the chain of calls `t.data_ptr() --> create_{1d,2d}_tma_descriptor(ptr, ...) --> kernel[grid](tma_desc, ...)` (sketched below), we add three new variable trackers: `DataPtrVariable`, `CreateTMADescriptorVariable` (for the function), and `TMADescriptorVariable` (for the TMA descriptor object). This maintains the path back from the Triton kernel to the Tensor from which the TMA descriptor was created.
- The newly introduced variables have `reconstruct` methods, used in case of graph breaks.
- The `tma_descriptor_metadata` extracted from the captured `create_{1d,2d}_tma_descriptor` calls is propagated through the HOPs in Dynamo and AOTAutograd to be used by the downstream compiler (e.g., Inductor). See the unit tests for what the captured HOP arguments look like.
- In the Dynamo-captured fx graph, we replace the TMA descriptor arguments of the Triton kernel with the underlying Tensors, so that the input/output relationships can be tracked in terms of Tensors.
- In the Triton kernel mutation analysis pass (in AOTAutograd), we use the `tt.experimental_descriptor_store` TTIR op to detect mutations of the underlying tensors via TMA descriptors, so that downstream AOTAutograd can perform functionalization as required.
- JIT Inductor and AOT Inductor support will be implemented in follow-up PRs.

Differential Revision: [D64404928](https://our.internmc.facebook.com/intern/diff/D64404928)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137677
Approved by: https://github.com/zou3519
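As an illustration, here is a minimal sketch (not part of the commit itself) of the user-code pattern that the new variable trackers capture, adapted from the `test_tma_descriptor_1d` unit test in the diff below. It assumes a Triton build that ships the experimental host-side TMA API and a TMA-capable CUDA GPU (i.e. `has_triton_tma()` returns True); `add_kernel_with_tma_1d` is the helper kernel this PR adds to `torch/testing/_internal/triton_utils.py`, while `add_via_tma` is just an illustrative name.

```python
import torch
import triton
from torch.testing._internal.triton_utils import add_kernel_with_tma_1d
from triton.tools.experimental_descriptor import create_1d_tma_descriptor


def add_via_tma(a, b):
    BLOCK_SIZE = 256
    out = torch.zeros_like(a)
    n_elements = out.numel()

    # Host-side TMA descriptors, one per tensor argument of the kernel
    # (the t.data_ptr() -> create_1d_tma_descriptor(...) chain that Dynamo traces).
    desc_a, desc_b, desc_out = (
        create_1d_tma_descriptor(t.data_ptr(), n_elements, BLOCK_SIZE, t.element_size())
        for t in (a, b, out)
    )

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel_with_tma_1d[grid](desc_a, desc_b, desc_out, BLOCK_SIZE=BLOCK_SIZE)
    return out


a = torch.randn(301, device="cuda")
b = torch.randn(301, device="cuda")

# In the captured fx graph, Dynamo replaces the TMA descriptor arguments with the
# underlying tensors and records the dims/block dims/element size as
# `tma_descriptor_metadata` on the triton_kernel_wrapper_* HOPs.
compiled = torch.compile(add_via_tma, backend="eager", fullgraph=True)
torch.testing.assert_close(compiled(a, b), a + b)
```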
Commit: 809ff3b274
Parent: dd2ae7d0c9
Committed by: PyTorch MergeBot
@@ -11534,6 +11534,41 @@ fn
        self.assertEqual(expected, actual)
        self.assertGreater(po.call_count, 0)

    def test_data_ptr_graph_break_builtin(self):
        def f(a, b):
            # builtin + not implemented for DataPtrVariable
            return a.data_ptr() + b.data_ptr()

        a = torch.randn(4)
        b = torch.randn(5)

        # make sure there is a graph break
        with self.assertRaises(torch._dynamo.exc.Unsupported):
            torch.compile(f, backend="eager", fullgraph=True)(a, b)

        torch._dynamo.reset()

        expected = f(a, b)
        actual = torch.compile(f, backend="eager")(a, b)

        self.assertEqual(expected, actual)

    def test_data_ptr_graph_break_aten(self):
        def f(a):
            # torch.add not implemented for DataPtrVariable
            return torch.add(a, a.data_ptr())

        a = torch.randn(4)

        counters.clear()

        expected = f(a)
        actual = torch.compile(f, backend="eager")(a)

        self.assertEqual(expected, actual)
        self.assertTrue(len(counters["graph_break"]) > 0)
        counters.clear()

    class AssertNumOutputBackend:
        """
        A backend that checks the number of output for compiled graph, and
@@ -29,7 +29,7 @@ from torch.testing._internal.logging_utils import logs_to_string

# Defines all the kernels for tests
from torch.testing._internal.triton_utils import *  # noqa: F403
from torch.utils._triton import has_triton_package
from torch.utils._triton import has_triton_package, has_triton_tma


if HAS_GPU:
@@ -99,6 +99,7 @@ class KernelTests(torch._inductor.test_case.TestCase):
            kernel_idx=add_kernel_id,
            constant_args_idx=constant_args_idx,
            grid=[grid],
            tma_descriptor_metadata={},
            kwargs={
                "in_ptr0": t1,
                "in_ptr1": t2,
@@ -115,6 +116,7 @@ class KernelTests(torch._inductor.test_case.TestCase):
            kernel_idx=add_kernel_id,
            constant_args_idx=constant_args_idx,
            grid=[grid],
            tma_descriptor_metadata={},
            kwargs={
                "in_ptr0": t1,
                "in_ptr1": t2,
@@ -145,6 +147,7 @@ class KernelTests(torch._inductor.test_case.TestCase):
                {"n_elements": output.numel(), "BLOCK_SIZE": 16}
            ),
            grid=[(x.numel(),)],
            tma_descriptor_metadata={},
            kwargs={
                "in_ptr0": x,
                "out_ptr": output,
@@ -173,7 +176,7 @@ class KernelTests(torch._inductor.test_case.TestCase):
            gm.code.strip(),
            """\
def forward(self, x_1, output_1):
    triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 3, grid = [(5,)], kwargs = {'in_ptr0': x_1, 'out_ptr': output_1}, tensors_to_clone = ['in_ptr0', 'out_ptr']); x_1 = output_1 = None
    triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 3, grid = [(5,)], tma_descriptor_metadata = {}, kwargs = {'in_ptr0': x_1, 'out_ptr': output_1}, tensors_to_clone = ['in_ptr0', 'out_ptr']); x_1 = output_1 = None
    getitem = triton_kernel_wrapper_functional_proxy['in_ptr0']; getitem = None
    getitem_1 = triton_kernel_wrapper_functional_proxy['out_ptr']; triton_kernel_wrapper_functional_proxy = None
    return getitem_1""",
@@ -217,6 +220,7 @@ def forward(self, x_1, output_1):
                {"n_elements": x_func.numel(), "BLOCK_SIZE": 16}
            ),
            grid=[(x_func.numel(),)],
            tma_descriptor_metadata={},
            kwargs={
                "ptr": x_func,
            },
@@ -238,6 +242,7 @@ def forward(self, x_1, output_1):
                {"n_elements": x_func.numel(), "BLOCK_SIZE": 16}
            ),
            grid=[(x_func.numel(),)],
            tma_descriptor_metadata={},
            kwargs={
                "ptr": x_func,
            },
@@ -1630,6 +1635,218 @@ def forward(self, x_1, output_1):
        self.assertEqual(out2, x + y + 1)
        self.assertEqual(out3, z**2)

    @requires_gpu
    @unittest.skipIf(not has_triton_tma(), "requires Triton TMA support")
    @common_utils.parametrize("dynamic", [False, True])
    def test_tma_capture_and_functionalize(self, dynamic):
        def f(a, b):
            BLOCK_SIZE = 256
            out = torch.zeros_like(a)
            n_elements = out.numel()

            desc_a, desc_b, desc_out = (
                triton.tools.experimental_descriptor.create_1d_tma_descriptor(
                    t.data_ptr(),
                    n_elements,
                    BLOCK_SIZE,
                    t.element_size(),
                )
                for t in (a, b, out)
            )

            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel_with_tma_1d[grid](
                desc_a,
                desc_b,
                desc_out,
                BLOCK_SIZE=BLOCK_SIZE,
            )

            return out

        a = torch.randn(301, device=GPU_TYPE)
        b = torch.randn(301, device=GPU_TYPE)

        backend = torch._dynamo.testing.AOTEagerAndRecordGraphs()
        torch.compile(
            f,
            fullgraph=True,
            backend=backend,
            dynamic=dynamic,
        )(a, b)

        if dynamic:
            self.assertExpectedInline(
                backend.graphs[0].code.strip(),
                """\
def forward(self, arg0_1, arg1_1, arg2_1):
    zeros_like = torch.ops.aten.zeros_like.default(arg1_1, pin_memory = False)
    add_2 = arg0_1 + 256
    sub_1 = add_2 - 1; add_2 = None
    floordiv = sub_1 // 256; sub_1 = None
    triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 1, grid = [(floordiv, 1, 1)], tma_descriptor_metadata = {'in_desc_ptr0': ([arg0_1], [256], 4), 'in_desc_ptr1': ([arg0_1], [256], 4), 'out_desc_ptr': ([arg0_1], [256], 4)}, kwargs = {'in_desc_ptr0': arg1_1, 'in_desc_ptr1': arg2_1, 'out_desc_ptr': zeros_like}, tensors_to_clone = ['out_desc_ptr']); floordiv = arg0_1 = arg1_1 = arg2_1 = zeros_like = None
    getitem = triton_kernel_wrapper_functional_proxy['out_desc_ptr']; triton_kernel_wrapper_functional_proxy = None
    return (getitem,)""",
            )
        else:
            self.assertExpectedInline(
                backend.graphs[0].code.strip(),
                """\
def forward(self, arg0_1, arg1_1):
    zeros_like = torch.ops.aten.zeros_like.default(arg0_1, pin_memory = False)
    triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 0, grid = [(2, 1, 1)], tma_descriptor_metadata = {'in_desc_ptr0': ([301], [256], 4), 'in_desc_ptr1': ([301], [256], 4), 'out_desc_ptr': ([301], [256], 4)}, kwargs = {'in_desc_ptr0': arg0_1, 'in_desc_ptr1': arg1_1, 'out_desc_ptr': zeros_like}, tensors_to_clone = ['out_desc_ptr']); arg0_1 = arg1_1 = zeros_like = None
    getitem = triton_kernel_wrapper_functional_proxy['out_desc_ptr']; triton_kernel_wrapper_functional_proxy = None
    return (getitem,)""",
            )

    @requires_gpu
    @unittest.skipIf(not has_triton_tma(), "requires Triton TMA support")
    @common_utils.parametrize("after_data_ptr", [False, True])
    @common_utils.parametrize("after_create_desc", [False, True])
    def test_tma_graph_breaks(self, after_data_ptr, after_create_desc):
        def f(a, b):
            BLOCK_SIZE = 256
            out = torch.zeros_like(a)
            n_elements = out.numel()

            ptrs = [t.data_ptr() for t in (a, b, out)]

            if after_data_ptr:
                torch._dynamo.graph_break()

            descs = [
                triton.tools.experimental_descriptor.create_1d_tma_descriptor(
                    ptr,
                    n_elements,
                    BLOCK_SIZE,
                    t.element_size(),
                )
                for ptr in ptrs
            ]

            if after_create_desc:
                torch._dynamo.graph_break()

            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel_with_tma_1d[grid](
                *descs,
                BLOCK_SIZE=BLOCK_SIZE,
            )

            return out

        a = torch.randn(301, device=GPU_TYPE)
        b = torch.randn(301, device=GPU_TYPE)

        expected_out = a + b
        eager_out = f(a, b)
        compiled_out = torch.compile(
            f,
            fullgraph=False,
            backend="eager",
            dynamic=False,
        )(a, b)

        self.assertEqual(eager_out, expected_out)
        self.assertEqual(compiled_out, expected_out)

    @requires_gpu
    @unittest.skipIf(not has_triton_tma(), "requires Triton TMA support")
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("backend", ["eager", "aot_eager"])
    def test_tma_descriptor_1d(self, dynamic, backend):
        def f(a, b):
            BLOCK_SIZE = 256
            out = torch.zeros_like(a)
            n_elements = out.numel()

            desc_a, desc_b, desc_out = (
                triton.tools.experimental_descriptor.create_1d_tma_descriptor(
                    t.data_ptr(),
                    n_elements,
                    BLOCK_SIZE,
                    t.element_size(),
                )
                for t in (a, b, out)
            )

            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel_with_tma_1d[grid](
                desc_a,
                desc_b,
                desc_out,
                BLOCK_SIZE=BLOCK_SIZE,
            )

            return out

        a = torch.randn(301, device=GPU_TYPE)
        b = torch.randn(301, device=GPU_TYPE)

        expected_out = a + b
        eager_out = f(a, b)
        compiled_out = torch.compile(
            f,
            fullgraph=True,
            backend=backend,
            dynamic=dynamic,
        )(a, b)

        self.assertEqual(eager_out, expected_out)
        self.assertEqual(compiled_out, expected_out)

    @requires_gpu
    @unittest.skipIf(not has_triton_tma(), "requires Triton TMA support")
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("backend", ["eager", "aot_eager"])
    def test_tma_descriptor_2d(self, dynamic, backend):
        def f(a, b):
            BLOCK_SIZE_X = 16
            BLOCK_SIZE_Y = 32
            out = torch.zeros_like(a)
            x_size, y_size = out.size()

            desc_a, desc_b, desc_out = (
                triton.tools.experimental_descriptor.create_2d_tma_descriptor(
                    t.data_ptr(),
                    x_size,
                    y_size,
                    BLOCK_SIZE_X,
                    BLOCK_SIZE_Y,
                    t.element_size(),
                )
                for t in (a, b, out)
            )

            grid = lambda meta: (
                triton.cdiv(x_size, meta["BLOCK_SIZE_X"]),
                triton.cdiv(y_size, meta["BLOCK_SIZE_Y"]),
            )
            add_kernel_with_tma_2d[grid](
                desc_a,
                desc_b,
                desc_out,
                BLOCK_SIZE_X=BLOCK_SIZE_X,
                BLOCK_SIZE_Y=BLOCK_SIZE_Y,
            )

            return out

        a = torch.randn((25, 16), device=GPU_TYPE)
        b = torch.randn((25, 16), device=GPU_TYPE)

        expected_out = a + b
        eager_out = f(a, b)
        compiled_out = torch.compile(
            f,
            fullgraph=True,
            backend=backend,
            dynamic=dynamic,
        )(a, b)

        self.assertEqual(eager_out, expected_out)
        self.assertEqual(compiled_out, expected_out)

    @requires_gpu
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_num_ctas(self, backend):
@@ -120,11 +120,21 @@ def boxed_nop(fx_g, example_inputs):

# Useful for debugging purpose
# aot_eager uses AOT Autograd backend with nop compiler. It is helpful in debugging.
aot_eager = aot_autograd(
    fw_compiler=boxed_nop,
    partition_fn=min_cut_rematerialization_partition,
    keep_inference_input_mutations=True,
)
def aot_eager(
    gm,
    fake_tensor_inputs,
    fw_compiler=None,
    bw_compiler=None,
    **kwargs,
):
    return aot_autograd(
        fw_compiler=fw_compiler or boxed_nop,
        bw_compiler=bw_compiler or boxed_nop,
        partition_fn=min_cut_rematerialization_partition,
        keep_inference_input_mutations=True,
    )(gm, fake_tensor_inputs, **kwargs)


register_backend(name="aot_eager", compiler_fn=aot_eager)

aot_eager_default_partitioner = aot_autograd(
@@ -24,6 +24,7 @@ from unittest.mock import patch

import torch
from torch import fx
from torch._dynamo.backends.debugging import aot_eager
from torch._dynamo.output_graph import OutputGraph

from . import config, eval_frame, optimize_assert, reset
@@ -245,6 +246,27 @@ class EagerAndRecordGraphs:
        return gm.forward


# Equivalent to backend="aot_eager", but also records graphs that
# we can assert on
class AOTEagerAndRecordGraphs:
    def __init__(self) -> None:
        self.graphs: List[torch.fx.GraphModule] = []

    def __call__(
        self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
    ) -> Callable[..., Any]:
        def save_graph(gm: torch.fx.GraphModule, *args: Any, **kwargs: Any) -> Any:
            self.graphs.append(gm)
            return gm.forward

        return aot_eager(
            gm,
            example_inputs,
            fw_compiler=save_graph,
            bw_compiler=save_graph,
        )


def strip_comment(code: str) -> str:
    return re.sub(r"(?m)^ *#.*\n?", "", code)
@@ -30,10 +30,12 @@ from .dicts import (
)
from .distributed import BackwardHookVariable, DistributedVariable, PlacementVariable
from .functions import (
    CreateTMADescriptorVariable,
    FunctoolsPartialVariable,
    NestedUserFunctionVariable,
    PolyfilledFunctionVariable,
    SkipFunctionVariable,
    TMADescriptorVariable,
    UserFunctionVariable,
    UserMethodVariable,
)
@@ -95,6 +97,7 @@ from .nn_module import (
from .optimizer import OptimizerVariable
from .sdpa import SDPAParamsVariable
from .tensor import (
    DataPtrVariable,
    FakeItemVariable,
    NumpyNdarrayVariable,
    SymNodeVariable,
@@ -124,9 +127,11 @@ __all__ = [
    "ConstDictVariable",
    "ContextWrappingVariable",
    "CountIteratorVariable",
    "CreateTMADescriptorVariable",
    "CUDADeviceVariable",
    "CustomizedDictVariable",
    "CycleIteratorVariable",
    "DataPtrVariable",
    "DefaultDictVariable",
    "DeletedVariable",
    "DeterministicAlgorithmsVariable",
@@ -163,6 +168,7 @@ __all__ = [
    "StringFormatVariable",
    "SuperVariable",
    "TensorVariable",
    "TMADescriptorVariable",
    "TorchCtxManagerClassVariable",
    "TorchInGraphFunctionVariable",
    "TorchVersionVariable",
@@ -141,6 +141,7 @@ from .distributed import (
)
from .functions import (
    CollectiveFunctionRewriteVariable,
    CreateTMADescriptorVariable,
    FunctoolsPartialVariable,
    TritonKernelVariable,
    UserFunctionVariable,
@@ -525,7 +526,7 @@ class VariableBuilder:

    def _wrap(self, value):
        # import here to avoid circular dependencies
        from torch.utils._triton import has_triton
        from torch.utils._triton import has_triton, has_triton_tma

        if has_triton():
            from triton.runtime.autotuner import Autotuner
@@ -538,6 +539,19 @@ class VariableBuilder:
            class Autotuner:
                pass

        if has_triton_tma():
            from triton.tools.experimental_descriptor import (
                create_1d_tma_descriptor,
                create_2d_tma_descriptor,
            )
        else:

            def create_1d_tma_descriptor():
                pass

            def create_2d_tma_descriptor():
                pass

        # Handle exact type() match
        type_dispatch = self._type_dispatch().get(type(value))
        if type_dispatch is not None:
@@ -967,6 +981,10 @@ class VariableBuilder:
                None,  # No grid provided
                source=self.source,
            )
        elif value is create_1d_tma_descriptor:
            return CreateTMADescriptorVariable(rank=1)
        elif value is create_2d_tma_descriptor:
            return CreateTMADescriptorVariable(rank=2)
        elif isinstance(value, torch.amp.autocast_mode.autocast):
            self.install_guards(GuardBuilder.ID_MATCH)
            return AutocastModeVariable(
@@ -1038,7 +1038,10 @@ class PolyfilledFunctionVariable(VariableTracker):
        return self.fn


from torch._higher_order_ops.triton_kernel_wrap import TritonHOPifier
from torch._higher_order_ops.triton_kernel_wrap import (
    TMADescriptorMetadata,
    TritonHOPifier,
)


class DynamoTritonHOPifier(TritonHOPifier):
@@ -1070,6 +1073,18 @@ class DynamoTritonHOPifier(TritonHOPifier):
        from .constant import ConstantVariable
        from .dicts import ConstDictVariable

        # as we can only pass tensors as non-const args in fx graph,
        # here we replace TMA descriptors (TMADescriptorVariable
        # instances) with the underlying tensors, while moving the
        # TMA descriptor-related metadata to a separate argument,
        # so that we can reconstruct the TMA descriptors downstream
        tma_descriptor_metadata: TMADescriptorMetadata = {}
        for k in list(combined_args_raw.keys()):
            v = combined_args_raw[k]
            if isinstance(v, TMADescriptorVariable):
                tma_descriptor_metadata[k] = v.to_metadata()
                combined_args_raw[k] = v.data_ptr.from_tensor

        combined_args = {
            variables.ConstantVariable.create(k): v
            for k, v in combined_args_raw.items()
@@ -1094,6 +1109,13 @@ class DynamoTritonHOPifier(TritonHOPifier):
            if not isinstance(v, ConstantVariable)
        }

        for v in non_constant_args.values():
            v = v.realize()
            if not isinstance(v, (variables.TensorVariable, variables.SymNodeVariable)):
                self.raise_unsupported(
                    f"Unexpected argument type for a Triton kernel: {repr(v)}."
                )

        constant_args_idx = kernel_side_table.add_constant_args(constant_args)
        meta = ConstDictVariable(non_constant_args, dict)
        tx.output.create_proxy(
@@ -1104,6 +1126,7 @@ class DynamoTritonHOPifier(TritonHOPifier):
                "kernel_idx": variable.kernel_idx,
                "constant_args_idx": constant_args_idx,
                "grid": grids,
                "tma_descriptor_metadata": tma_descriptor_metadata,
                "kwargs": meta.as_proxy(),
            },
        )
@@ -1154,3 +1177,93 @@ class TritonKernelVariable(VariableTracker):
        if isinstance(arg, SymNodeVariable):
            return ConstantVariable.create(arg.evaluate_expr())
        return arg


class TMADescriptorVariable(VariableTracker):
    def __init__(
        self,
        data_ptr: "variables.DataPtrVariable",
        dims: "List[ConstantVariable]",
        block_dims: "List[ConstantVariable]",
        element_size: "ConstantVariable",
        **kwargs,
    ):
        assert isinstance(data_ptr, variables.DataPtrVariable)

        super().__init__(**kwargs),
        self.data_ptr = data_ptr
        self.dims = dims
        self.block_dims = block_dims
        self.element_size = element_size

    def to_metadata(self):
        return (
            [dim.as_proxy() for dim in self.dims],
            [dim.as_proxy() for dim in self.block_dims],
            self.element_size.as_proxy(),
        )

    def reconstruct(self, codegen):
        codegen.add_push_null(
            lambda: codegen.load_import_from(
                "triton.tools.experimental_descriptor",
                f"create_{len(self.dims)}d_tma_descriptor",
            )
        )
        self.data_ptr.reconstruct(codegen)
        args = [*self.dims, *self.block_dims, self.element_size]
        codegen.foreach(args)
        codegen.call_function(len(args) + 1, False)


class CreateTMADescriptorVariable(VariableTracker):
    def __init__(
        self,
        rank: int,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs),
        assert rank in (1, 2)
        self.rank = rank

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        ptr = kwargs["ptr"] if "ptr" in kwargs else args[0]

        if not isinstance(ptr, variables.DataPtrVariable):
            raise Unsupported(
                "Please ensure there were no graph breaks between "
                f"create_{self.rank}d_tma_descriptor and the upstream "
                ".data_ptr() call."
            )

        if self.rank == 1:
            assert len(args) + len(kwargs) == 4
            dims = [
                kwargs["dim"] if "dim" in kwargs else args[1],
            ]
            block_dims = [
                kwargs["block_dim"] if "block_dim" in kwargs else args[2],
            ]
        else:
            assert len(args) + len(kwargs) == 6
            dims = [
                kwargs["dim1"] if "dim1" in kwargs else args[1],
                kwargs["dim0"] if "dim0" in kwargs else args[2],
            ]
            block_dims = [
                kwargs["block_dim1"] if "block_dim1" in kwargs else args[3],
                kwargs["block_dim2"] if "block_dim2" in kwargs else args[4],
            ]
        element_size = kwargs["ptr"] if "ptr" in kwargs else args[-1]

        return TMADescriptorVariable(
            data_ptr=ptr,
            dims=dims,
            block_dims=block_dims,
            element_size=element_size,
        )
@@ -793,7 +793,7 @@ class TensorVariable(VariableTracker):
        unimplemented("Tensor.backward")

    def method_data_ptr(self, *args, **kwargs):
        unimplemented("Tensor.data_ptr")
        return DataPtrVariable(self)

    def method_item(self, *args, **kwargs):
        if not config.capture_scalar_outputs:
@@ -1439,3 +1439,18 @@ class UntypedStorageVariable(VariableTracker):
        codegen(self.from_tensor)
        codegen.load_method("untyped_storage")
        codegen.call_method(0)


class DataPtrVariable(VariableTracker):
    def __init__(
        self,
        from_tensor: TensorVariable,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs),
        self.from_tensor = from_tensor

    def reconstruct(self, codegen):
        codegen(self.from_tensor, allow_cache=False)
        codegen.load_method("data_ptr")
        codegen.call_method(0)
@@ -6,11 +6,11 @@ import inspect
import logging
import threading
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch.fx as fx
import torch.utils._pytree as pytree
from torch import Tensor
from torch import SymInt, Tensor
from torch._C import DispatchKey
from torch._ops import HigherOrderOperator
from torch._prims_common import clone_preserve_strides
@@ -25,6 +25,21 @@ from torch.fx.experimental.symbolic_shapes import guard_scalar

log = logging.getLogger("torch._dynamo")

# TMADescriptorMetadata maps kernel parameter names to the metadata that allows
# reconstructing TMA descriptors from the underlying tensors (passed as kernel
# arguments in the fx graph, instead of the TMA descriptors). Namely: a tuple
# consisting of list of dims, list of block dims, and element size. E.g., for this
# call in host-side Triton TMA API ``create_2d_tma_descriptor(ptr, 50, 60, 32, 15, 4)``,
# the metadata will look like ``([50, 60], [32, 15], 4)``. All ints can be SymInts.
TMADescriptorMetadata = Dict[
    str,  # kernel parameter name
    Tuple[
        List[Union[int, SymInt]],  # dims
        List[Union[int, SymInt]],  # block_dims
        Union[int, SymInt],  # element_size
    ],
]


###############################################################################
# Kernel Side Table
@@ -422,7 +437,12 @@ def analyze_kernel_mutations(functions, fn_name, num_args):
    # List from Triton Github include/triton/Dialect/Triton/IR/TritonOps.td
    # All the OPs that have MemWrite trait.
    # What if Triton exposed this?
    MUTATION_OPS = {"tt.store": [0], "tt.atomic_cas": [0], "tt.atomic_rmw": [0]}
    MUTATION_OPS = {
        "tt.store": [0],
        "tt.atomic_cas": [0],
        "tt.atomic_rmw": [0],
        "tt.experimental_descriptor_store": [0],
    }
    # Ops that we want to bail out on
    UNKNOWN_OPS = {"tt.elementwise_inline_asm"}
@@ -524,11 +544,19 @@ class TritonKernelWrapperMutation(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("triton_kernel_wrapper_mutation", cacheable=False)

    def __call__(self, kernel_idx, constant_args_idx, grid, kwargs):
    def __call__(
        self,
        kernel_idx,
        constant_args_idx,
        grid,
        tma_descriptor_metadata: TMADescriptorMetadata,
        kwargs,
    ):
        return super().__call__(
            kernel_idx=kernel_idx,
            constant_args_idx=constant_args_idx,
            grid=grid,
            tma_descriptor_metadata=tma_descriptor_metadata,
            kwargs=kwargs,
        )
@@ -541,11 +569,20 @@ class TritonKernelWrapperFunctional(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("triton_kernel_wrapper_functional", cacheable=False)

    def __call__(self, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone):
    def __call__(
        self,
        kernel_idx,
        constant_args_idx,
        grid,
        tma_descriptor_metadata: TMADescriptorMetadata,
        kwargs,
        tensors_to_clone,
    ):
        return super().__call__(
            kernel_idx=kernel_idx,
            constant_args_idx=constant_args_idx,
            grid=grid,
            tma_descriptor_metadata=tma_descriptor_metadata,
            kwargs=kwargs,
            tensors_to_clone=tensors_to_clone,
        )
@@ -556,7 +593,12 @@ triton_kernel_wrapper_functional = TritonKernelWrapperFunctional()

@triton_kernel_wrapper_mutation.py_impl(DispatchKey.CompositeExplicitAutograd)
def triton_kernel_wrapper_mutation_dense(
    *, kernel_idx, constant_args_idx, grid, kwargs
    *,
    kernel_idx,
    constant_args_idx,
    grid,
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs,
):
    from torch._inductor.codegen.wrapper import user_defined_kernel_grid_fn_code

@@ -573,19 +615,56 @@ def triton_kernel_wrapper_mutation_dense(
        exec(code, namespace)
        grid_fn = namespace[fn_name]

    if tma_descriptor_metadata:
        from triton.tools.experimental_descriptor import (  # noqa: F401
            create_1d_tma_descriptor,
            create_2d_tma_descriptor,
        )

        # as we need to launch the kernel here, we "unwrap" the
        # tma_descriptor_metadata, create the TMA descriptors
        # from it, and replace the tensors in the kwargs by the
        # corresponding TMA descriptors before launching
        kwargs = kwargs.copy()
        for k, v in tma_descriptor_metadata.items():
            tensor = kwargs[k]
            dims, block_dims, element_size = v
            create_tma_descriptor = (
                create_1d_tma_descriptor if len(dims) == 1 else create_2d_tma_descriptor
            )
            kwargs[k] = create_tma_descriptor(
                tensor.data_ptr(),
                *dims,
                *block_dims,
                element_size,
            )

    kernel[grid_fn](**kwargs, **constant_args)


@triton_kernel_wrapper_mutation.py_impl(FakeTensorMode)
def triton_kernel_wrapper_mutation_fake_tensor_mode(
    mode, *, kernel_idx, constant_args_idx, grid, kwargs
    mode,
    *,
    kernel_idx,
    constant_args_idx,
    grid,
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs,
):
    with mode:
        return None


@triton_kernel_wrapper_mutation.py_impl(DispatchKey.Meta)
def _(*, kernel_idx, constant_args_idx, grid, kwargs):
def _(
    *,
    kernel_idx,
    constant_args_idx,
    grid,
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs,
):
    return None
@@ -607,7 +686,13 @@ def trace_triton_kernel_wrapper(proxy_mode, func_overload, node_args):

@triton_kernel_wrapper_mutation.py_impl(ProxyTorchDispatchMode)
def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode(
    mode, *, kernel_idx, constant_args_idx, grid, kwargs
    mode,
    *,
    kernel_idx,
    constant_args_idx,
    grid,
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs,
):
    trace_triton_kernel_wrapper(
        mode,
@@ -616,6 +701,7 @@ def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode(
            "kernel_idx": kernel_idx,
            "constant_args_idx": constant_args_idx,
            "grid": grid,
            "tma_descriptor_metadata": tma_descriptor_metadata,
            "kwargs": kwargs,
        },
    )
@@ -631,7 +717,12 @@ def get_mutated_tensors(kernel_idx, constant_args_idx, kwargs):

@triton_kernel_wrapper_mutation.py_functionalize_impl
def triton_kernel_wrapper_mutation_functionalize(
    ctx, kernel_idx, constant_args_idx, grid, kwargs
    ctx,
    kernel_idx,
    constant_args_idx,
    grid,
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs,
):
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    # TODO(oulgen): Preexisting bug, if two kernel inputs are views of each
@@ -646,6 +737,7 @@ def triton_kernel_wrapper_mutation_functionalize(
            kernel_idx=kernel_idx,
            constant_args_idx=constant_args_idx,
            grid=grid,
            tma_descriptor_metadata=tma_descriptor_metadata,
            kwargs=unwrapped_kwargs,
            tensors_to_clone=tensors_to_clone,
        )
@@ -667,7 +759,13 @@ def triton_kernel_wrapper_mutation_functionalize(

@triton_kernel_wrapper_functional.py_impl(DispatchKey.CompositeExplicitAutograd)
def triton_kernel_wrapper_functional_dense(
    *, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone
    *,
    kernel_idx,
    constant_args_idx,
    grid,
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs,
    tensors_to_clone,
):
    # TODO(oulgen): For performance reasons, we want to ensure that these
    # `clone_preserve_strides` calls are never executed at runtime
@@ -681,6 +779,7 @@ def triton_kernel_wrapper_functional_dense(
        kernel_idx=kernel_idx,
        constant_args_idx=constant_args_idx,
        grid=grid,
        tma_descriptor_metadata=tma_descriptor_metadata,
        kwargs=kwargs,
    )
    return {key: val for key, val in kwargs.items() if key in tensors_to_clone}
@@ -688,7 +787,14 @@ def triton_kernel_wrapper_functional_dense(

@triton_kernel_wrapper_functional.py_impl(FakeTensorMode)
def triton_kernel_wrapper_functional_fake_tensor_mode(
    mode, *, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone
    mode,
    *,
    kernel_idx,
    constant_args_idx,
    grid,
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs,
    tensors_to_clone,
):
    # TODO(oulgen): For performance reasons, we want to ensure that these
    # `clone_preserve_strides` calls are never executed at runtime
@@ -704,7 +810,14 @@ def triton_kernel_wrapper_functional_fake_tensor_mode(

@triton_kernel_wrapper_functional.py_impl(ProxyTorchDispatchMode)
def triton_kernel_wrapper_functional_proxy_torch_dispatch_mode(
    mode, *, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone
    mode,
    *,
    kernel_idx,
    constant_args_idx,
    grid,
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs,
    tensors_to_clone,
):
    return trace_triton_kernel_wrapper(
        mode,
@@ -713,6 +826,7 @@ def triton_kernel_wrapper_functional_proxy_torch_dispatch_mode(
            "kernel_idx": kernel_idx,
            "constant_args_idx": constant_args_idx,
            "grid": grid,
            "tma_descriptor_metadata": tma_descriptor_metadata,
            "kwargs": kwargs,
            "tensors_to_clone": tensors_to_clone,
        },
@@ -721,7 +835,13 @@ def triton_kernel_wrapper_functional_proxy_torch_dispatch_mode(

@triton_kernel_wrapper_functional.py_functionalize_impl
def triton_kernel_wrapper_functional_functionalize(
    ctx, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone
    ctx,
    kernel_idx,
    constant_args_idx,
    grid,
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs,
    tensors_to_clone,
):
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    with ctx.redispatch_to_next():
@@ -729,6 +849,7 @@ def triton_kernel_wrapper_functional_functionalize(
            kernel_idx=kernel_idx,
            constant_args_idx=constant_args_idx,
            grid=grid,
            tma_descriptor_metadata=tma_descriptor_metadata,
            kwargs=unwrapped_kwargs,
            tensors_to_clone=tensors_to_clone,
        )
@@ -1057,6 +1178,9 @@ class TracingTritonHOPifier(TritonHOPifier):
            kernel_idx=variable.kernel_idx,
            constant_args_idx=constant_args_idx,
            grid=grids,
            # TMA descriptor capturing not yet
            # supported in non-dynamo tracing
            tma_descriptor_metadata={},
            kwargs=graphable_args,
        )
@@ -6322,7 +6322,14 @@ make_fallback(auto_functionalized)


@register_lowering(triton_kernel_wrapper_mutation)
def triton_kernel_wrap_(*, kernel_idx, constant_args_idx, grid, kwargs):
def triton_kernel_wrap_(
    *,
    kernel_idx,
    constant_args_idx,
    grid,
    tma_descriptor_metadata,
    kwargs,
):
    from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

    constant_args = kernel_side_table.get_constant_args(constant_args_idx)
@@ -191,6 +191,71 @@ if has_triton():
        output = (x + y) * scaling_factor
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def add_kernel_with_tma_1d(
        in_desc_ptr0,
        in_desc_ptr1,
        out_desc_ptr,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        offset = pid * BLOCK_SIZE

        a = tl._experimental_descriptor_load(
            in_desc_ptr0,
            [offset],
            [BLOCK_SIZE],
            tl.float32,
        )
        b = tl._experimental_descriptor_load(
            in_desc_ptr1,
            [offset],
            [BLOCK_SIZE],
            tl.float32,
        )

        output = a + b

        tl._experimental_descriptor_store(
            out_desc_ptr,
            output,
            [offset],
        )

    @triton.jit
    def add_kernel_with_tma_2d(
        in_desc_ptr0,
        in_desc_ptr1,
        out_desc_ptr,
        BLOCK_SIZE_X: "tl.constexpr",
        BLOCK_SIZE_Y: "tl.constexpr",
    ):
        pid_x = tl.program_id(axis=0)
        pid_y = tl.program_id(axis=1)
        offset_x = pid_x * BLOCK_SIZE_X
        offset_y = pid_y * BLOCK_SIZE_Y

        x = tl._experimental_descriptor_load(
            in_desc_ptr0,
            [offset_x, offset_y],
            [BLOCK_SIZE_X, BLOCK_SIZE_Y],
            tl.float32,
        )
        y = tl._experimental_descriptor_load(
            in_desc_ptr1,
            [offset_x, offset_y],
            [BLOCK_SIZE_X, BLOCK_SIZE_Y],
            tl.float32,
        )

        output = x + y

        tl._experimental_descriptor_store(
            out_desc_ptr,
            output,
            [offset_x, offset_y],
        )

    @triton.jit
    def mul2_kernel(
        in_ptr0,
@@ -15,6 +15,29 @@ def has_triton_package() -> bool:
    return False


@functools.lru_cache(None)
def has_triton_tma():
    if has_triton_package():
        import torch

        if (
            torch.cuda.is_available()
            and torch.cuda.get_device_capability() >= (9, 0)
            and not torch.version.hip
        ):
            try:
                from triton.tools.experimental_descriptor import (  # noqa: F401
                    create_1d_tma_descriptor,
                    create_2d_tma_descriptor,
                )

                return True
            except ImportError:
                pass

    return False


@functools.lru_cache(None)
def has_triton() -> bool:
    if not has_triton_package():