[Dynamo] Analyze triton kernels via tracing to determine mutations (#117300)

This PR adds TTIR lexing and parsing in order to analyze which of the user-defined Triton kernel inputs are mutated.
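
As a quick orientation, here is a minimal sketch of the new entry point, mirroring the MutationTests added in this PR (it assumes Triton, lark, and a CUDA device are available; the add_kernel below is an inline stand-in for the test kernel of the same name):

import torch
import triton
import triton.language as tl

from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors

@triton.jit
def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

t = torch.randn(4)
kwargs = {"in_ptr0": t, "in_ptr1": t, "out_ptr": t, "n_elements": 4, "BLOCK_SIZE": 4}
# Only out_ptr is the target of a tl.store, so it is the only input reported as mutated;
# if the analysis fails (e.g. lark is missing), every tensor argument is reported instead.
print(identify_mutated_tensors(add_kernel, kwargs))  # ["out_ptr"]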

Differential Revision: [D53165999](https://our.internmc.facebook.com/intern/diff/D53165999)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117300
Approved by: https://github.com/jansel
Authored by: Oguz Ulgen (2024-01-28 11:13:08 -08:00)
Committed by: PyTorch MergeBot
parent 2951bbf0f7
commit 47b5a6b05d
6 changed files with 591 additions and 86 deletions

View File

@@ -52,6 +52,11 @@ junitparser==2.1.1
#Pinned versions: 2.1.1
#test that import:
lark==0.12.0
#Description: parser
#Pinned versions: 0.12.0
#test that import:
librosa>=0.6.2 ; python_version < "3.11"
#Description: A python package for music and audio analysis
#Pinned versions: >=0.6.2

View File

@@ -18,3 +18,4 @@ fsspec
setuptools ; python_version >= "3.12"
packaging
optree>=0.9.1
lark

View File

@@ -7,6 +7,8 @@ import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo import config
from torch._dynamo.testing import make_test_cls_with_patches
from torch._higher_order_ops.triton_kernel_wrap import (
triton_kernel_wrapper_functional,
@@ -893,22 +895,159 @@ def forward(self, x_1, output_1):
class MutationTests(torch._dynamo.test_case.TestCase):
@requires_cuda
@requires_lark
def test_find_mutations(self):
from torch._higher_order_ops.triton_kernel_wrap import filter_non_mutated
from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors
t = torch.randn(4)
tests = [
[add_kernel, ["in_ptr0", "in_ptr1", "out_ptr"], ["out_ptr"]],
[add_kernel_2d_autotuned, ["in_ptr0", "in_ptr1", "out_ptr"], ["out_ptr"]],
# Cannot remove in_ptr0 since it is used in an external call
[indirection_kernel, ["in_ptr0", "out_ptr"], ["in_ptr0", "out_ptr"]],
[mul2_inplace_kernel, ["ptr"], ["ptr"]],
[
add_kernel,
{
"in_ptr0": t,
"in_ptr1": t,
"out_ptr": t,
"n_elements": 4,
"BLOCK_SIZE": 4,
},
["out_ptr"],
],
[
add_kernel_out_of_order,
{
"in_ptr0": t,
"n_elements": 4,
"in_ptr1": t,
"out_ptr": t,
"BLOCK_SIZE": 4,
},
["out_ptr"],
],
[
add_kernel_2d_autotuned,
{
"in_ptr0": t,
"in_ptr1": t,
"out_ptr": t,
"x_elements": 4,
"y_elements": 4,
},
["out_ptr"],
],
[
indirection_kernel,
{
"in_ptr0": t,
"out_ptr": t,
"n_elements": 4,
"BLOCK_SIZE": 4,
"ACTIVATION": "mul2_inplace_kernel",
},
["in_ptr0", "out_ptr"],
],
[
indirection_kernel,
{
"in_ptr0": t,
"out_ptr": t,
"n_elements": 4,
"BLOCK_SIZE": 4,
"ACTIVATION": "add_kernel",
},
# TODO(oulgen): Multiple functions are not implemented yet
["in_ptr0", "out_ptr"],
],
[
mul2_inplace_kernel,
{"ptr": t, "n_elements": 4, "BLOCK_SIZE": 4},
["ptr"],
],
# Can't optimize since the kernel contains a tl.inline_asm_elementwise
[
inline_asm_kernel,
{"X": t, "Y": t, "Z": t, "n": 4, "BLOCK": 4},
["X", "Y", "Z"],
],
[
add_kernel_with_block_ptr,
{
"x_ptr": t,
"y_ptr": t,
"output_ptr": t,
"n_elements": 4,
"BLOCK_SIZE": 4,
},
["output_ptr"],
],
[
add_kernel_with_import,
{
"in_ptr0": t,
"in_ptr1": t,
"out_ptr": t,
"n_elements": 4,
"BLOCK_SIZE": 4,
},
["out_ptr"],
],
[
atomic_add_kernel,
{
"in_ptr0": t,
"in_ptr1": t,
"out_ptr": t,
"n_elements": 4,
"BLOCK_SIZE": 4,
},
["out_ptr"],
],
[
add_4_times_kernel,
{
"in_ptr0": t,
"in_ptr1": t,
"out_ptr": t,
"n_elements": 4,
"BLOCK_SIZE": 4,
},
# TODO(oulgen): For loops are not implemented yet
["in_ptr0", "in_ptr1", "out_ptr"],
],
[
cond_op_kernel,
{
"in_ptr0": t,
"in_ptr1": t,
"out_ptr": t,
"n_elements": 4,
"BLOCK_SIZE": 4,
},
# TODO(oulgen): Dynamic control flow is not implemented yet
["in_ptr0", "in_ptr1", "out_ptr"],
],
]
for kernel, inputs, outputs in tests:
self.assertListEqual(filter_non_mutated(kernel, inputs), outputs)
self.assertListEqual(
identify_mutated_tensors(kernel, inputs),
outputs,
msg=f"while testing {kernel.fn.__name__}",
)
common_utils.instantiate_parametrized_tests(KernelTests)
no_opt_test_class = make_test_cls_with_patches(
KernelTests,
"NoOptimization",
"_no_optimizations",
(config, "optimize_user_defined_triton_kernels", False),
)
globals()[no_opt_test_class.__name__] = no_opt_test_class
no_opt_test_class.__module__ = __name__
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@@ -326,6 +326,9 @@ capture_autograd_function = True
# enable/disable dynamo tracing for `torch.func` transforms
capture_func_transforms = False
# enable/disable user-defined triton kernel optimizations
optimize_user_defined_triton_kernels = True
# Whether to log Dynamo compilation metrics into log files (for OSS) and Scuba tables (for fbcode).
log_compilation_metrics = True
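
As a usage note (hypothetical user code, not part of this patch), the new flag can be flipped to fall back to the conservative behavior of treating every tensor argument as mutated:

import torch._dynamo.config as dynamo_config

# Disable the TTIR-based mutation analysis globally ...
dynamo_config.optimize_user_defined_triton_kernels = False

# ... or only temporarily, mirroring the NoOptimization test class added above.
with dynamo_config.patch(optimize_user_defined_triton_kernels=False):
    pass  # compile and run code that launches user-defined Triton kernels here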

View File

@@ -1,7 +1,8 @@
import ast
import dataclasses
import logging
import threading
from typing import Any, Dict
import warnings
from typing import Any, Dict, List, Union
import torch.utils._pytree as pytree
from torch import Tensor
@@ -15,6 +16,8 @@ from torch.fx.experimental.proxy_tensor import (
track_tensor_tree,
)
log = logging.getLogger("torch._dynamo")
###############################################################################
# Kernel Side Table
@@ -60,91 +63,278 @@ kernel_side_table = KernelSideTable()
# Mutation Tracker
@dataclasses.dataclass(frozen=True)
class Param:
idx: int
@dataclasses.dataclass(frozen=True)
class Intermediate:
idx: int
def fake(self):
return self.idx < 0
@dataclasses.dataclass
class MutationInfo:
mutated: bool = False
used_in_unknown: bool = False
class Op:
name: str
args: List[Union[Param, Intermediate]]
# Super basic mutation tracking pass that tracks which inputs are used in stores.
# It bails if any of the inputs are used in positions other than tl.load/tl.store.
# This pass will miss simple things like
#   a = in_ptr
#   tl.load(a, ...)
# since it does not do any contextual analysis. This means that we might incorrectly
# report extra mutations, but this is safe: it would only be incorrect to miss a
# mutation.
class MutationTracker(ast.NodeVisitor):
ALLOWED_READ_FNS = {
"load",
"max_constancy",
"max_contiguous",
"multiple_of",
"static_print",
"static_assert",
"device_print",
"device_assert",
}
def __init__(self, infos) -> None:
super().__init__()
self.infos = infos
self.read_depth = 0
self.in_store = False
def visit_Name(self, node):
if node.id not in self.infos:
return
if self.read_depth:
pass
elif self.in_store:
self.infos[node.id].mutated = True
else:
self.infos[node.id].used_in_unknown = True
def visit_Call(self, node):
# TODO(oulgen): Here we assume that there exists a line of the form
# `from triton import language as tl`. This needs to be checked,
# since if someone imports xyz as tl then we will incorrectly
# assume a mutation, but this would be ok, as it is only unsafe to
# miss a mutation.
if (
isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
and node.func.value.id == "tl"
):
if node.func.attr == "store":
# Do not allow for store to appear inside a read
# tl.load(a if tl.store(b) else z) is not useful
# and allowing this would complicate the analysis
assert self.read_depth == 0
assert self.in_store is False
self.in_store = True
self.generic_visit(node)
self.in_store = False
return
if node.func.attr in self.ALLOWED_READ_FNS:
self.read_depth += 1
self.generic_visit(node)
self.read_depth -= 1
return
self.generic_visit(node)
def filter_non_mutated(kernel, tensors):
def generate_ttir(kernel, kwargs):
"""
Uses Triton's internal code generation to create TTIR
"""
import triton
from triton.compiler.compiler import ASTSource
from triton.runtime.autotuner import Autotuner
from triton.runtime.jit import JITFunction
import torch
from torch._subclasses.fake_tensor import FakeTensor
if isinstance(kernel, Autotuner):
if len(kernel.configs) > 0:
# If we are autotuning, then it doesn't matter which version gets
# picked for tracing purposes, so let's pick the first one
kwargs = {**kwargs, **kernel.configs[0].kwargs}
kernel = kernel.fn
infos = {name: MutationInfo() for name in tensors}
tracker = MutationTracker(infos)
tracker.visit(kernel.parse())
assert isinstance(kernel, JITFunction)
if len(kwargs) != len(kernel.arg_names):
raise Exception("Incorrect number of arguments passed to kernel")
# Drop the keys
# Replace all SymExprs with a regular value for TTIR generation
# Replace all FakeTensor with real tensors
# These replacements are needed for triton's type, key and config functions
args: List[Any] = []
for name in kernel.arg_names:
a = kwargs[name]
if isinstance(a, (torch.SymInt, torch.SymFloat, torch.SymBool)):
args.append(2)
elif isinstance(a, FakeTensor):
args.append(torch.empty(2, dtype=a.dtype))
else:
args.append(a)
tensor_param_locs = [i for i, arg in enumerate(args) if isinstance(arg, Tensor)]
specialization = kernel._get_config(*args)
constants = {i: arg for i, arg in enumerate(args) if not isinstance(arg, Tensor)}
# Build kernel signature -- doesn't include constexpr arguments.
signature = {
i: kernel._type_of(kernel._key_of(arg))
for i, arg in enumerate(args)
if i not in kernel.constexprs
}
context = triton._C.libtriton.ir.context()
target = triton.runtime.driver.active.get_current_target()
backend = triton.compiler.compiler.make_backend(target)
options = backend.parse_options(dict())
triton._C.libtriton.ir.load_dialects(context)
backend.load_dialects(context)
src = ASTSource(kernel, signature, constants, specialization)
ttir_module = src.make_ir(options, context)
return str(ttir_module), tensor_param_locs
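# For illustration (not from the original patch): for the add_kernel used in
# MutationTests, generate_ttir would return the textual MLIR module -- a single tt.func
# containing tt.get_program_id / tt.load / tt.addptr / tt.store style ops, much like the
# samples quoted in parse_ttir below -- together with tensor_param_locs == [0, 1, 2],
# i.e. the positions of in_ptr0, in_ptr1 and out_ptr within kernel.arg_names.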
def parse_ttir(ttir, kwargs):
"""
Given Triton-emitted TTIR text, this function lexes and parses the
code using a minimal grammar defined inline. During lexing/parsing,
we drop any constant values and type information, as they are not
necessary for us.
Being able to pick out only what we need keeps this from being a
general-purpose TTIR parser, which in turn makes the parsing much simpler.
"""
# TODO(oulgen):
# - Support multiple functions
# - Support parsing of conditionals
# - Support parsing for/while loops
# - Support ops with multiple return values (e.g. %4:2 = "tt.reduce")
if ttir.count("tt.func") != 1:
log.debug("Multiple functions in TTIR")
return None
try:
import lark
from lark import Lark, Transformer, v_args
except ModuleNotFoundError:
warnings.warn(
"Using slow path for user-defined Triton kernels. `pip install lark` to fix this."
)
raise
ops: Dict[Intermediate, Op] = {}
next_fake_intermediate = 0
# Ops look like one of the following forms:
#
# %14 = tt.addptr %13, %4 : tensor<4x!tt.ptr<f32, 1>>, tensor<4xi32>
# tt.store %14, %12, %5 {cache = 1 : i32, evict = 1 : i32} : tensor<4xf32>
# %15 = "tt.atomic_rmw"(%14, %12, %5) <{atomic_rmw_op = 5 : i32, scope = 1 : i32, sem = 4 : i32}> : (tensor<4x!tt.ptr<f32, 1>>, tensor<4xf32>, tensor<4xi1>) -> tensor<4xf32> # noqa: B950
grammar = """
start: (module_block | loc_line)+
loc_line: "#loc" /.+/ NEWLINE
module_block: "module" "{" decl_block "}" LOC
decl_block: /.+/ NEWLINE op+ "}" LOC
op: "tt.return" LOC
| assign_lhs "=" OP_NAME args rest -> process_op
| OP_NAME args rest -> process_op_no_ret
?rest: (":" | "{" | "\\"" | "->" | "<") /.+/ NEWLINE
args: | "("? arg ("," arg)* ")"?
?arg: INTERMEDIATE | CONSTANT | PARAM | "[" arg "]"
?assign_lhs: INTERMEDIATE | CONSTANT
PARAM.5: "%arg" DIGIT+
INTERMEDIATE.4: "%" DIGIT+
NAME: (LETTER | DIGIT | "_")+
CONSTANT: "%"? NAME+ ("<" DIGIT+ ">")?
OP_NAME: "\\""? NAME "." NAME "\\""?
LOC: "loc(#loc" DIGIT* ")"
%import common.LETTER
%import common.DIGIT
%import common.WS
%import common.NEWLINE
%import common.ESCAPED_STRING
%ignore WS
"""
def convert(token):
if isinstance(token, lark.tree.Tree):
return [convert(a) for a in token.children]
if token is None or (
isinstance(token, lark.lexer.Token) and token.type == "CONSTANT"
):
nonlocal next_fake_intermediate
next_fake_intermediate -= 1
return Intermediate(next_fake_intermediate)
assert isinstance(token, lark.lexer.Token)
if token.type == "INTERMEDIATE":
return Intermediate(int(token.value[len("%") :]))
if token.type == "PARAM":
return Param(int(token.value[len("%arg") :]))
raise AssertionError(f"{type(token.type)} => {token.value} invalid")
# In the alternative representation, function names are quoted.
# It should be possible to move this into the grammar altogether.
def convert_name(token):
s = token.value
if len(s) > 2 and s[0] == '"' and s[-1] == '"':
return s[1:-1]
return s
@v_args(inline=True)
class CalculateOps(Transformer):
def process_op_no_ret(self, *args):
self.process_op(None, *args)
def process_op(self, ret, name, args, *rest):
ops[convert(ret)] = Op(convert_name(name), convert(args))
parser = Lark(grammar, parser="lalr", transformer=CalculateOps())
parser.parse(ttir)
return ops
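# For illustration (not from the original patch), the sample op lines quoted above would
# roughly be parsed into
#   Intermediate(14)  -> Op("tt.addptr",     [Intermediate(13), Intermediate(4)])
#   Intermediate(-1)  -> Op("tt.store",      [Intermediate(14), Intermediate(12), Intermediate(5)])
#   Intermediate(15)  -> Op("tt.atomic_rmw", [Intermediate(14), Intermediate(12), Intermediate(5)])
# where the tt.store, having no SSA result, is keyed by a fresh fake (negative)
# Intermediate, and all constant values and type annotations have been dropped.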
def analyze_kernel_mutations(ops, kwargs, tensor_param_locs):
"""
Analyzes the graph to detect all sinks from a predefined list of sinks
(Triton's MemWrite trait list). NOTE: What if Triton exposed this?
From each sink, it traverses the CFG backwards to identify all of the
input pointers that are mutated.
"""
# Name of mutation op to mutated parameter indices
# List from the Triton GitHub repo: include/triton/Dialect/Triton/IR/TritonOps.td
# All the OPs that have MemWrite trait.
# What if Triton exposed this?
MUTATION_OPS = {"tt.store": [0], "tt.atomic_cas": [0], "tt.atomic_rmw": [0]}
# Ops that we want to bail out on
UNKNOWN_OPS = {"tt.elementwise_inline_asm"}
stack: List[Union[Param, Intermediate]] = []
visited = set()
for op in ops.values():
if op.name in UNKNOWN_OPS:
raise Exception(
f"ttir analysis hit an op we do not know how to analyze: {op.name}"
)
for idx in MUTATION_OPS.get(op.name, []):
stack.append(op.args[idx])
# The following is an iterative DFS algorithm
mutated = [False] * len(kwargs)
while len(stack):
arg = stack.pop()
if arg in visited:
continue
else:
visited.add(arg)
if isinstance(arg, Param):
mutated[tensor_param_locs[arg.idx]] = True
elif isinstance(arg, Intermediate) and not arg.fake():
stack.extend(ops[arg].args)
return [
name for name, info in infos.items() if info.mutated or info.used_in_unknown
key
for i, (key, value) in enumerate(kwargs.items())
if isinstance(value, Tensor) and mutated[i]
]
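# Worked example for analyze_kernel_mutations (illustrative, not from the original patch):
# for the add_kernel TTIR, the only sinks are tt.store ops, and the DFS is seeded with
# their pointer operand (args[0]) only. Walking that chain of tt.addptr/tt.splat producers
# backwards reaches the Param for out_ptr, so out_ptr is the only argument marked as
# mutated; in_ptr0 and in_ptr1 only ever feed the pointer chains of tt.load ops, which are
# never used as roots, so they remain unmutated.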
def identify_mutated_tensors(kernel, kwargs):
"""
Given a triton kernel and the arguments for this kernel, this function
1) Retrieves the TTIR-converted version of the kernel from Triton's API
2) Parses the TTIR and creates a control flow graph
3) Analyzes the graph to detect all input tensor mutations
"""
try:
from torch._dynamo import config
if not config.optimize_user_defined_triton_kernels:
raise Exception("optimize_user_defined_triton_kernels is False")
ttir, tensor_param_locs = generate_ttir(kernel, kwargs)
ops = parse_ttir(ttir, kwargs)
return analyze_kernel_mutations(ops, kwargs, tensor_param_locs)
except Exception as e:
import traceback
log.debug(
"Encountered an exception in identify_mutated_tensors, assuming every input is mutated"
)
log.debug(
"".join(
traceback.TracebackException.from_exception(e).format() # noqa: G001
)
)
return [key for key, value in kwargs.items() if isinstance(value, Tensor)]
###############################################################################
# Triton Kernel Wrappers
@@ -226,15 +416,12 @@ def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode(
@triton_kernel_wrapper_mutation.py_functionalize_impl
def triton_kernel_wrapper_mutation_functionalize(ctx, kernel_idx, grid, kwargs):
unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
tensors_to_clone = [
key for key, value in unwrapped_kwargs.items() if isinstance(value, Tensor)
]
kernel = kernel_side_table.get_kernel(kernel_idx)
# TODO(oulgen): Preexisting bug, if two kernel inputs are views of each
# other, and one gets mutated in kernel, and later another gets mutated,
# they are no longer equal. Fix this by graph breaking on this condition
# earlier in dynamo.
tensors_to_clone = filter_non_mutated(kernel, tensors_to_clone)
tensors_to_clone = identify_mutated_tensors(kernel, unwrapped_kwargs)
with ctx.redispatch_to_next():
unwrapped_outputs = triton_kernel_wrapper_functional(
kernel_idx=kernel_idx,

View File

@@ -4,6 +4,17 @@ import unittest
from torch.testing._internal.inductor_utils import HAS_CUDA
def has_lark():
try:
import lark # noqa: F401
return True
except ModuleNotFoundError:
return False
requires_lark = unittest.skipUnless(has_lark(), "requires lark")
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
if HAS_CUDA:
@@ -28,6 +39,23 @@ if HAS_CUDA:
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)
@triton.jit
def add_kernel_out_of_order(
in_ptr0,
n_elements,
in_ptr1,
out_ptr,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)
@triton.jit
def add_kernel_with_optional_param(
in_ptr0,
@@ -162,6 +190,8 @@ if HAS_CUDA:
mask = offsets < n_elements
if ACTIVATION == "mul2_inplace_kernel":
mul2_inplace_kernel(in_ptr0, n_elements, BLOCK_SIZE=BLOCK_SIZE)
elif ACTIVATION == "add_kernel":
add_kernel(in_ptr0, in_ptr0, out_ptr, n_elements, BLOCK_SIZE=BLOCK_SIZE)
x = tl.load(in_ptr0 + offsets, mask=mask)
tl.store(out_ptr + offsets, x, mask=mask)
@@ -184,3 +214,143 @@ if HAS_CUDA:
dst_offsets = y_offsets[:, None] * out_y_stride + x_offsets[None, :]
src = tl.load(in_ptr + src_offsets)
tl.store(out_ptr + dst_offsets, src * 2.0)
@triton.jit
def inline_asm_kernel(X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr"):
x = tl.load(X + tl.arange(0, BLOCK))
y = tl.load(Y + tl.arange(0, BLOCK))
s = tl.full([BLOCK], n, tl.int32)
z = tl.inline_asm_elementwise(
"shf.l.wrap.b32 $0, $1, $2, $3;",
"=r,r, r, r",
[x, y, s],
dtype=tl.int32,
is_pure=True,
pack=1,
)
tl.store(Z + tl.arange(0, BLOCK), z)
@triton.jit
def add_kernel_with_block_ptr(
x_ptr,
y_ptr,
output_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
x = tl.load(
tl.make_block_ptr(
base=x_ptr,
shape=[n_elements],
strides=[1],
offsets=[block_start],
block_shape=[BLOCK_SIZE],
order=[0],
),
boundary_check=[0],
)
y = tl.load(
tl.make_block_ptr(
base=y_ptr,
shape=[n_elements],
strides=[1],
offsets=[block_start],
block_shape=[BLOCK_SIZE],
order=[0],
),
boundary_check=[0],
)
output = x + y
tl.store(
tl.make_block_ptr(
base=output_ptr,
shape=[n_elements],
strides=[1],
offsets=[block_start],
block_shape=[BLOCK_SIZE],
order=[0],
),
output,
boundary_check=[0],
)
from triton.language import load, store
@triton.jit
def add_kernel_with_import(
in_ptr0,
in_ptr1,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = load(in_ptr0 + offsets, mask=mask)
y = load(in_ptr1 + offsets, mask=mask)
output = x + y
store(out_ptr + offsets, output, mask=mask)
@triton.jit
def cond_op_kernel(
in_ptr0,
in_ptr1,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
if tl.program_id(0) == 0:
output = x + y
else:
output = x * y
tl.store(out_ptr + offsets, output, mask=mask)
@triton.jit
def atomic_add_kernel(
in_ptr0,
in_ptr1,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = x + y
tl.atomic_add(out_ptr + offsets, output, mask=mask)
@triton.jit
def add_4_times_kernel(
in_ptr0,
in_ptr1,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
for i in range(2):
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)
i = 2
while i > 0:
i -= 1
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)