Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Dynamo] Analyze triton kernels via tracing to determine mutations (#117300)
This PR adds TTIR lexing and parsing in order to analyze which of the user-defined Triton kernel inputs are mutated.

Differential Revision: [D53165999](https://our.internmc.facebook.com/intern/diff/D53165999)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117300
Approved by: https://github.com/jansel
Committed by: PyTorch MergeBot
Parent: 2951bbf0f7
Commit: 47b5a6b05d
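For orientation, here is a minimal sketch of how the new analysis is exercised, mirroring the `MutationTests` added in this diff; it is not part of the commit itself. It assumes a CUDA-enabled build with `triton` and `lark` installed, and the simple `add_kernel` below is written just for the example.

```python
# Minimal sketch (not part of the diff): exercising the TTIR-based mutation
# analysis the way the new MutationTests do. Assumes CUDA, triton, and lark.
import torch
import triton
import triton.language as tl

from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors


@triton.jit
def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


t = torch.randn(4)
kwargs = {"in_ptr0": t, "in_ptr1": t, "out_ptr": t, "n_elements": 4, "BLOCK_SIZE": 4}

# The kernel is traced to TTIR, the TTIR is parsed with lark, and only pointers
# reachable from store-like ops (tt.store, tt.atomic_*) are reported as mutated.
print(identify_mutated_tensors(add_kernel, kwargs))  # expected: ["out_ptr"]
```

If any step fails (for example `lark` is missing, or the TTIR contains constructs the parser does not handle yet), the analysis falls back to treating every tensor argument as mutated, which is always safe.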
@@ -52,6 +52,11 @@ junitparser==2.1.1
#Pinned versions: 2.1.1
#test that import:

lark==0.12.0
#Description: parser
#Pinned versions: 0.12.0
#test that import:

librosa>=0.6.2 ; python_version < "3.11"
#Description: A python package for music and audio analysis
#Pinned versions: >=0.6.2
@@ -18,3 +18,4 @@ fsspec
setuptools ; python_version >= "3.12"
packaging
optree>=0.9.1
lark
@@ -7,6 +7,8 @@ import torch

import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo import config
from torch._dynamo.testing import make_test_cls_with_patches

from torch._higher_order_ops.triton_kernel_wrap import (
    triton_kernel_wrapper_functional,
@@ -893,22 +895,159 @@ def forward(self, x_1, output_1):

class MutationTests(torch._dynamo.test_case.TestCase):
    @requires_cuda
    @requires_lark
    def test_find_mutations(self):
        from torch._higher_order_ops.triton_kernel_wrap import filter_non_mutated
        from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors

        t = torch.randn(4)

        tests = [
            [add_kernel, ["in_ptr0", "in_ptr1", "out_ptr"], ["out_ptr"]],
            [add_kernel_2d_autotuned, ["in_ptr0", "in_ptr1", "out_ptr"], ["out_ptr"]],
            # Cannot remove in_ptr0 since it is used in an external call
            [indirection_kernel, ["in_ptr0", "out_ptr"], ["in_ptr0", "out_ptr"]],
            [mul2_inplace_kernel, ["ptr"], ["ptr"]],
            [
                add_kernel,
                {
                    "in_ptr0": t,
                    "in_ptr1": t,
                    "out_ptr": t,
                    "n_elements": 4,
                    "BLOCK_SIZE": 4,
                },
                ["out_ptr"],
            ],
            [
                add_kernel_out_of_order,
                {
                    "in_ptr0": t,
                    "n_elements": 4,
                    "in_ptr1": t,
                    "out_ptr": t,
                    "BLOCK_SIZE": 4,
                },
                ["out_ptr"],
            ],
            [
                add_kernel_2d_autotuned,
                {
                    "in_ptr0": t,
                    "in_ptr1": t,
                    "out_ptr": t,
                    "x_elements": 4,
                    "y_elements": 4,
                },
                ["out_ptr"],
            ],
            [
                indirection_kernel,
                {
                    "in_ptr0": t,
                    "out_ptr": t,
                    "n_elements": 4,
                    "BLOCK_SIZE": 4,
                    "ACTIVATION": "mul2_inplace_kernel",
                },
                ["in_ptr0", "out_ptr"],
            ],
            [
                indirection_kernel,
                {
                    "in_ptr0": t,
                    "out_ptr": t,
                    "n_elements": 4,
                    "BLOCK_SIZE": 4,
                    "ACTIVATION": "add_kernel",
                },
                # TODO(oulgen): Multiple functions are not implemented yet
                ["in_ptr0", "out_ptr"],
            ],
            [
                mul2_inplace_kernel,
                {"ptr": t, "n_elements": 4, "BLOCK_SIZE": 4},
                ["ptr"],
            ],
            # Can't optimize since the kernel contains a tl.inline_asm_elementwise
            [
                inline_asm_kernel,
                {"X": t, "Y": t, "Z": t, "n": 4, "BLOCK": 4},
                ["X", "Y", "Z"],
            ],
            [
                add_kernel_with_block_ptr,
                {
                    "x_ptr": t,
                    "y_ptr": t,
                    "output_ptr": t,
                    "n_elements": 4,
                    "BLOCK_SIZE": 4,
                },
                ["output_ptr"],
            ],
            [
                add_kernel_with_import,
                {
                    "in_ptr0": t,
                    "in_ptr1": t,
                    "out_ptr": t,
                    "n_elements": 4,
                    "BLOCK_SIZE": 4,
                },
                ["out_ptr"],
            ],
            [
                atomic_add_kernel,
                {
                    "in_ptr0": t,
                    "in_ptr1": t,
                    "out_ptr": t,
                    "n_elements": 4,
                    "BLOCK_SIZE": 4,
                },
                ["out_ptr"],
            ],
            [
                add_4_times_kernel,
                {
                    "in_ptr0": t,
                    "in_ptr1": t,
                    "out_ptr": t,
                    "n_elements": 4,
                    "BLOCK_SIZE": 4,
                },
                # TODO(oulgen): For loops are not implemented yet
                ["in_ptr0", "in_ptr1", "out_ptr"],
            ],
            [
                cond_op_kernel,
                {
                    "in_ptr0": t,
                    "in_ptr1": t,
                    "out_ptr": t,
                    "n_elements": 4,
                    "BLOCK_SIZE": 4,
                },
                # TODO(oulgen): Dynamic control flow is not implemented yet
                ["in_ptr0", "in_ptr1", "out_ptr"],
            ],
        ]

        for kernel, inputs, outputs in tests:
            self.assertListEqual(filter_non_mutated(kernel, inputs), outputs)
            self.assertListEqual(
                identify_mutated_tensors(kernel, inputs),
                outputs,
                msg=f"while testing {kernel.fn.__name__}",
            )


common_utils.instantiate_parametrized_tests(KernelTests)

no_opt_test_class = make_test_cls_with_patches(
    KernelTests,
    "NoOptimization",
    "_no_optimizations",
    (config, "optimize_user_defined_triton_kernels", False),
)

globals()[no_opt_test_class.__name__] = no_opt_test_class
no_opt_test_class.__module__ = __name__

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

@@ -326,6 +326,9 @@ capture_autograd_function = True
# enable/disable dynamo tracing for `torch.func` transforms
capture_func_transforms = False

# enable/disable user-defined triton kernel optimizations
optimize_user_defined_triton_kernels = True

# If to log Dynamo compilation metrics into log files (for OSS) and Scuba tables (for fbcode).
log_compilation_metrics = True
@@ -1,7 +1,8 @@
import ast
import dataclasses
import logging
import threading
from typing import Any, Dict
import warnings
from typing import Any, Dict, List, Union

import torch.utils._pytree as pytree
from torch import Tensor
@@ -15,6 +16,8 @@ from torch.fx.experimental.proxy_tensor import (
    track_tensor_tree,
)

log = logging.getLogger("torch._dynamo")


###############################################################################
# Kernel Side Table
@@ -60,91 +63,278 @@ kernel_side_table = KernelSideTable()
# Mutation Tracker


@dataclasses.dataclass(frozen=True)
class Param:
    idx: int


@dataclasses.dataclass(frozen=True)
class Intermediate:
    idx: int

    def fake(self):
        return self.idx < 0


@dataclasses.dataclass
class MutationInfo:
    mutated: bool = False
    used_in_unknown: bool = False
class Op:
    name: str
    args: List[Union[Param, Intermediate]]


# Super basic mutation tracking pass that tracks which inputs are used in stores
# It bails if any of the inputs are used in non-tl.load/tl.store positions.
# This pass will miss simple things like
#   a = in_ptr
#   tl.load(a, ...)
# since it does not do any contextual analysis. This means that we might incorrectly
# find extra mutations but this is safe as it would only be incorrect to miss
# mutations.
class MutationTracker(ast.NodeVisitor):
    ALLOWED_READ_FNS = {
        "load",
        "max_constancy",
        "max_contiguous",
        "multiple_of",
        "static_print",
        "static_assert",
        "device_print",
        "device_assert",
    }

    def __init__(self, infos) -> None:
        super().__init__()
        self.infos = infos
        self.read_depth = 0
        self.in_store = False

    def visit_Name(self, node):
        if node.id not in self.infos:
            return
        if self.read_depth:
            pass
        elif self.in_store:
            self.infos[node.id].mutated = True
        else:
            self.infos[node.id].used_in_unknown = True

    def visit_Call(self, node):
        # TODO(oulgen): Here we assume that there exists a line called
        # from triton import language as tl. This needs to be checked
        # as if someone imports xyz as tl then we will incorrectly
        # assume a mutation but this would be ok as it is only unsafe to
        # miss a mutation.
        if (
            isinstance(node.func, ast.Attribute)
            and isinstance(node.func.value, ast.Name)
            and node.func.value.id == "tl"
        ):
            if node.func.attr == "store":
                # Do not allow for store to appear inside a read
                # tl.load(a if tl.store(b) else z) is not useful
                # and allowing this would complicate the analysis
                assert self.read_depth == 0
                assert self.in_store is False
                self.in_store = True
                self.generic_visit(node)
                self.in_store = False
                return
            if node.func.attr in self.ALLOWED_READ_FNS:
                self.read_depth += 1
                self.generic_visit(node)
                self.read_depth -= 1
                return
        self.generic_visit(node)

def filter_non_mutated(kernel, tensors):
def generate_ttir(kernel, kwargs):
    """
    Uses Triton's internal code generation to create TTIR
    """
    import triton
    from triton.compiler.compiler import ASTSource
    from triton.runtime.autotuner import Autotuner
    from triton.runtime.jit import JITFunction

    import torch
    from torch._subclasses.fake_tensor import FakeTensor

    if isinstance(kernel, Autotuner):
        if len(kernel.configs) > 0:
            # If we are autotuning, then it doesn't matter which version gets
            # picked for tracing purposes, so let's pick the first one
            kwargs = {**kwargs, **kernel.configs[0].kwargs}
        kernel = kernel.fn

    infos = {name: MutationInfo() for name in tensors}
    tracker = MutationTracker(infos)
    tracker.visit(kernel.parse())
    assert isinstance(kernel, JITFunction)

    if len(kwargs) != len(kernel.arg_names):
        raise Exception("Incorrect number of arguments passed to kernel")

    # Drop the keys
    # Replace all SymExprs with a regular value for TTIR generation
    # Replace all FakeTensor with real tensors
    # These replacements are needed for triton's type, key and config functions
    args: List[Any] = []
    for name in kernel.arg_names:
        a = kwargs[name]
        if isinstance(a, (torch.SymInt, torch.SymFloat, torch.SymBool)):
            args.append(2)
        elif isinstance(a, FakeTensor):
            args.append(torch.empty(2, dtype=a.dtype))
        else:
            args.append(a)

    tensor_param_locs = [i for i, arg in enumerate(args) if isinstance(arg, Tensor)]
    specialization = kernel._get_config(*args)
    constants = {i: arg for i, arg in enumerate(args) if not isinstance(arg, Tensor)}

    # Build kernel signature -- doesn't include constexpr arguments.
    signature = {
        i: kernel._type_of(kernel._key_of(arg))
        for i, arg in enumerate(args)
        if i not in kernel.constexprs
    }

    context = triton._C.libtriton.ir.context()
    target = triton.runtime.driver.active.get_current_target()
    backend = triton.compiler.compiler.make_backend(target)
    options = backend.parse_options(dict())
    triton._C.libtriton.ir.load_dialects(context)
    backend.load_dialects(context)

    src = ASTSource(kernel, signature, constants, specialization)
    ttir_module = src.make_ir(options, context)
    return str(ttir_module), tensor_param_locs

def parse_ttir(ttir, kwargs):
    """
    Given Triton-emitted TTIR text, this function lexes and parses the
    code using a minimal grammar defined inside. During the lexing/parsing,
    we drop any constant value and type information as they are not
    necessary to us.
    Being able to choose what we need makes this not a general purpose TTIR
    parser which further makes parsing much simpler.
    """
    # TODO(oulgen):
    # - Support multiple functions
    # - Support parsing of conditionals
    # - Support parsing for/while loops
    # - Support ops with multiple return value (e.g. %4:2 = "tt.reduce")

    if ttir.count("tt.func") != 1:
        log.debug("Multiple functions in TTIR")
        return None

    try:
        import lark
        from lark import Lark, Transformer, v_args
    except ModuleNotFoundError:
        warnings.warn(
            "Using slow path for user-defined Triton kernels. `pip install lark` to fix this."
        )
        raise

    ops: Dict[Intermediate, Op] = {}
    next_fake_intermediate = 0

    # Ops look like one of the following forms:
    #
    # %14 = tt.addptr %13, %4 : tensor<4x!tt.ptr<f32, 1>>, tensor<4xi32>
    # tt.store %14, %12, %5 {cache = 1 : i32, evict = 1 : i32} : tensor<4xf32>
    # %15 = "tt.atomic_rmw"(%14, %12, %5) <{atomic_rmw_op = 5 : i32, scope = 1 : i32, sem = 4 : i32}> : (tensor<4x!tt.ptr<f32, 1>>, tensor<4xf32>, tensor<4xi1>) -> tensor<4xf32>  # noqa: B950
    grammar = """
    start: (module_block | loc_line)+

    loc_line: "#loc" /.+/ NEWLINE

    module_block: "module" "{" decl_block "}" LOC

    decl_block: /.+/ NEWLINE op+ "}" LOC

    op: "tt.return" LOC
      | assign_lhs "=" OP_NAME args rest -> process_op
      | OP_NAME args rest -> process_op_no_ret

    ?rest: (":" | "{" | "\\"" | "->" | "<") /.+/ NEWLINE

    args: | "("? arg ("," arg)* ")"?

    ?arg: INTERMEDIATE | CONSTANT | PARAM | "[" arg "]"

    ?assign_lhs: INTERMEDIATE | CONSTANT

    PARAM.5: "%arg" DIGIT+
    INTERMEDIATE.4: "%" DIGIT+
    NAME: (LETTER | DIGIT | "_")+
    CONSTANT: "%"? NAME+ ("<" DIGIT+ ">")?

    OP_NAME: "\\""? NAME "." NAME "\\""?

    LOC: "loc(#loc" DIGIT* ")"

    %import common.LETTER
    %import common.DIGIT
    %import common.WS
    %import common.NEWLINE
    %import common.ESCAPED_STRING
    %ignore WS
    """

    def convert(token):
        if isinstance(token, lark.tree.Tree):
            return [convert(a) for a in token.children]
        if token is None or (
            isinstance(token, lark.lexer.Token) and token.type == "CONSTANT"
        ):
            nonlocal next_fake_intermediate
            next_fake_intermediate -= 1
            return Intermediate(next_fake_intermediate)

        assert isinstance(token, lark.lexer.Token)

        if token.type == "INTERMEDIATE":
            return Intermediate(int(token.value[len("%") :]))
        if token.type == "PARAM":
            return Param(int(token.value[len("%arg") :]))

        raise AssertionError(f"{type(token.type)} => {token.value} invalid")

    # In alternative representation, function names are quoted.
    # It should be possible to move this into the grammar altogether.
    def convert_name(token):
        s = token.value
        if len(s) > 2 and s[0] == '"' and s[-1] == '"':
            return s[1:-1]
        return s

    @v_args(inline=True)
    class CalculateOps(Transformer):
        def process_op_no_ret(self, *args):
            self.process_op(None, *args)

        def process_op(self, ret, name, args, *rest):
            ops[convert(ret)] = Op(convert_name(name), convert(args))

    parser = Lark(grammar, parser="lalr", transformer=CalculateOps())
    parser.parse(ttir)
    return ops

def analyze_kernel_mutations(ops, kwargs, tensor_param_locs):
    """
    Analyzes the graph to detect all sinks from a predefined list of sinks
    by using triton's MemWrite trait list. NOTE: What if triton exposed this?
    From each sink, it traverses the CFG backwards to identify all the input
    pointers that are mutated
    """
    # Name of mutation op to mutated parameter indices
    # List from Triton Github include/triton/Dialect/Triton/IR/TritonOps.td
    # All the OPs that have MemWrite trait.
    # What if Triton exposed this?
    MUTATION_OPS = {"tt.store": [0], "tt.atomic_cas": [0], "tt.atomic_rmw": [0]}
    # Ops that we want to bail out on
    UNKNOWN_OPS = {"tt.elementwise_inline_asm"}

    stack: List[Union[Param, Intermediate]] = []
    visited = set()
    for op in ops.values():
        if op.name in UNKNOWN_OPS:
            raise Exception(
                f"ttir analysis hit an op we do not know how to analyze: {op.name}"
            )

        for idx in MUTATION_OPS.get(op.name, []):
            stack.append(op.args[idx])

    # The following is an iterative DFS algorithm
    mutated = [False] * len(kwargs)
    while len(stack):
        arg = stack.pop()
        if arg in visited:
            continue
        else:
            visited.add(arg)

        if isinstance(arg, Param):
            mutated[tensor_param_locs[arg.idx]] = True
        elif isinstance(arg, Intermediate) and not arg.fake():
            stack.extend(ops[arg].args)

    return [
        name for name, info in infos.items() if info.mutated or info.used_in_unknown
        key
        for i, (key, value) in enumerate(kwargs.items())
        if isinstance(value, Tensor) and mutated[i]
    ]


def identify_mutated_tensors(kernel, kwargs):
    """
    Given a triton kernel and the arguments for this kernel, this function
    1) Retrieves the TTIR converted version of the kernel from Triton's API.
    2) Parses the TTIR and creates a control flow graph
    3) Analyzes the graph to detect all input tensor mutations
    """

    try:
        from torch._dynamo import config

        if not config.optimize_user_defined_triton_kernels:
            raise Exception("optimize_user_defined_triton_kernels is False")

        ttir, tensor_param_locs = generate_ttir(kernel, kwargs)
        ops = parse_ttir(ttir, kwargs)
        return analyze_kernel_mutations(ops, kwargs, tensor_param_locs)
    except Exception as e:
        import traceback

        log.debug(
            "Encountered an exception in identify_mutated_tensors, assuming every input is mutated"
        )
        log.debug(
            "".join(
                traceback.TracebackException.from_exception(e).format()  # noqa: G001
            )
        )
        return [key for key, value in kwargs.items() if isinstance(value, Tensor)]


###############################################################################
# Triton Kernel Wrappers

@@ -226,15 +416,12 @@ def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode(
@triton_kernel_wrapper_mutation.py_functionalize_impl
def triton_kernel_wrapper_mutation_functionalize(ctx, kernel_idx, grid, kwargs):
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    tensors_to_clone = [
        key for key, value in unwrapped_kwargs.items() if isinstance(value, Tensor)
    ]
    kernel = kernel_side_table.get_kernel(kernel_idx)
    # TODO(oulgen): Preexisting bug, if two kernel inputs are views of each
    # other, and one gets mutated in kernel, and later another gets mutated,
    # they are no longer equal. Fix this by graph breaking on this condition
    # earlier in dynamo.
    tensors_to_clone = filter_non_mutated(kernel, tensors_to_clone)
    tensors_to_clone = identify_mutated_tensors(kernel, unwrapped_kwargs)
    with ctx.redispatch_to_next():
        unwrapped_outputs = triton_kernel_wrapper_functional(
            kernel_idx=kernel_idx,
@@ -4,6 +4,17 @@ import unittest

from torch.testing._internal.inductor_utils import HAS_CUDA


def has_lark():
    try:
        import lark  # noqa: F401

        return True
    except ModuleNotFoundError:
        return False


requires_lark = unittest.skipUnless(has_lark(), "requires lark")
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")

if HAS_CUDA:
@@ -28,6 +39,23 @@ if HAS_CUDA:
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def add_kernel_out_of_order(
        in_ptr0,
        n_elements,
        in_ptr1,
        out_ptr,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def add_kernel_with_optional_param(
        in_ptr0,
@@ -162,6 +190,8 @@ if HAS_CUDA:
        mask = offsets < n_elements
        if ACTIVATION == "mul2_inplace_kernel":
            mul2_inplace_kernel(in_ptr0, n_elements, BLOCK_SIZE=BLOCK_SIZE)
        elif ACTIVATION == "add_kernel":
            add_kernel(in_ptr0, in_ptr0, out_ptr, n_elements, BLOCK_SIZE=BLOCK_SIZE)
        x = tl.load(in_ptr0 + offsets, mask=mask)
        tl.store(out_ptr + offsets, x, mask=mask)

@@ -184,3 +214,143 @@ if HAS_CUDA:
        dst_offsets = y_offsets[:, None] * out_y_stride + x_offsets[None, :]
        src = tl.load(in_ptr + src_offsets)
        tl.store(out_ptr + dst_offsets, src * 2.0)

    @triton.jit
    def inline_asm_kernel(X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr"):
        x = tl.load(X + tl.arange(0, BLOCK))
        y = tl.load(Y + tl.arange(0, BLOCK))
        s = tl.full([BLOCK], n, tl.int32)
        z = tl.inline_asm_elementwise(
            "shf.l.wrap.b32 $0, $1, $2, $3;",
            "=r,r, r, r",
            [x, y, s],
            dtype=tl.int32,
            is_pure=True,
            pack=1,
        )
        tl.store(Z + tl.arange(0, BLOCK), z)

    @triton.jit
    def add_kernel_with_block_ptr(
        x_ptr,
        y_ptr,
        output_ptr,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        x = tl.load(
            tl.make_block_ptr(
                base=x_ptr,
                shape=[n_elements],
                strides=[1],
                offsets=[block_start],
                block_shape=[BLOCK_SIZE],
                order=[0],
            ),
            boundary_check=[0],
        )
        y = tl.load(
            tl.make_block_ptr(
                base=y_ptr,
                shape=[n_elements],
                strides=[1],
                offsets=[block_start],
                block_shape=[BLOCK_SIZE],
                order=[0],
            ),
            boundary_check=[0],
        )
        output = x + y
        tl.store(
            tl.make_block_ptr(
                base=output_ptr,
                shape=[n_elements],
                strides=[1],
                offsets=[block_start],
                block_shape=[BLOCK_SIZE],
                order=[0],
            ),
            output,
            boundary_check=[0],
        )

    from triton.language import load, store

    @triton.jit
    def add_kernel_with_import(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = load(in_ptr0 + offsets, mask=mask)
        y = load(in_ptr1 + offsets, mask=mask)
        output = x + y
        store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def cond_op_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        if tl.program_id(0) == 0:
            output = x + y
        else:
            output = x * y
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def atomic_add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.atomic_add(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def add_4_times_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        for i in range(2):
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)
        i = 2
        while i > 0:
            i -= 1
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)