[inductor] for UserDefinedTritonKernels don't mark all inputs as mutating (#124425)

Take this example:
```
def _mul2(x):
    y = torch.empty_like(x)
    mul2_kernel[(10,)](
        in_ptr0=x, out_ptr=y,
        n_elements=x.numel(), BLOCK_SIZE=1,
    )
    return y

def f(x):
    for _ in range(4):
        x = _mul2(x)
    return x + 1
```
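
(`mul2_kernel` is not shown above; a minimal kernel consistent with that call, doubling each element, would look roughly like the sketch below. Note that it only reads `in_ptr0` and only writes `out_ptr`.)
```
import triton
import triton.language as tl

@triton.jit
def mul2_kernel(in_ptr0, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles BLOCK_SIZE contiguous elements.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * 2.0, mask=mask)
```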

Currently, the generated code looks like this. Notice how we allocate 5 buffers of the same size: because every tensor input to the user-defined Triton kernel is marked as mutating (and therefore as aliasing the kernel's output), the scheduler never considers the input buffers dead and cannot reuse them.
```
# Source Nodes: [triton_kernel_wrapper_mutation], Original ATen: []
buf0 = empty_strided_cuda((10, ), (1, ), torch.float32)
mul2_kernel_0.run(in_ptr0=arg0_1, out_ptr=reinterpret_tensor(buf0, (10, ), (1, ), 0) ...)

# Source Nodes: [triton_kernel_wrapper_mutation_1], Original ATen: []
buf4 = empty_strided_cuda((10, ), (1, ), torch.float32)
mul2_kernel_0.run(in_ptr0=reinterpret_tensor(buf0, (10, ), (1, ), 0), out_ptr=reinterpret_tensor(buf4, (10, ), (1, ), 0) ...)

# Source Nodes: [triton_kernel_wrapper_mutation_2], Original ATen: []
buf8 = empty_strided_cuda((10, ), (1, ), torch.float32)
mul2_kernel_0.run(in_ptr0=reinterpret_tensor(buf4, (10, ), (1, ), 0), out_ptr=reinterpret_tensor(buf8, (10, ), (1, ), 0) ...)

# Source Nodes: [triton_kernel_wrapper_mutation_3], Original ATen: []
buf12 = empty_strided_cuda((10, ), (1, ), torch.float32)
mul2_kernel_0.run(in_ptr0=reinterpret_tensor(buf8, (10, ), (1, ), 0), out_ptr=reinterpret_tensor(buf12, (10, ), (1, ), 0) ...)

# Source Nodes: [add], Original ATen: [aten.add]
buf16 = empty_strided_cuda((10, ), (1, ), torch.float32)
triton_poi_fused_add_0.run(buf12, buf16, 10, grid=grid(10), stream=stream0)
return (buf16, )
```

With this PR, we want to see this instead. Notice how we only allocate 2 buffers this time; the other 3 buffers are reused. Since only the arguments the kernel actually writes to (here, `out_ptr`) are marked as mutating, each input buffer becomes dead after its last read and can be reused for a later output.
```
# Source Nodes: [triton_kernel_wrapper_mutation], Original ATen: []
buf0 = empty_strided_cuda((10, ), (1, ), torch.float32)
mul2_kernel_0.run(in_ptr0=arg0_1, out_ptr=reinterpret_tensor(buf0, (10, ), (1, ), 0), ...)
del arg0_1

# Source Nodes: [triton_kernel_wrapper_mutation_1], Original ATen: []
buf2 = empty_strided_cuda((10, ), (1, ), torch.float32)
mul2_kernel_0.run(in_ptr0=reinterpret_tensor(buf0, (10, ), (1, ), 0), out_ptr=reinterpret_tensor(buf2, (10, ), (1, ), 0) ...)

# Source Nodes: [triton_kernel_wrapper_mutation_2], Original ATen: []
buf4 = buf0; del buf0  # reuse
mul2_kernel_0.run(in_ptr0=reinterpret_tensor(buf2, (10, ), (1, ), 0), out_ptr=reinterpret_tensor(buf4, (10, ), (1, ), 0) ...)

# Source Nodes: [triton_kernel_wrapper_mutation_3], Original ATen: []
buf6 = buf2; del buf2  # reuse
mul2_kernel_0.run(in_ptr0=reinterpret_tensor(buf4, (10, ), (1, ), 0), out_ptr=reinterpret_tensor(buf6, (10, ), (1, ), 0) ...)
del buf4

# Source Nodes: [add], Original ATen: [aten.add]
buf8 = buf6; del buf6  # reuse
triton_poi_fused_add_0.run(buf8, 10, grid=grid(10), stream=stream0)
return (buf8, )
```
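
To see the difference yourself, one illustrative way (not part of this PR) is to grab the generated wrapper code with `run_and_get_code` and count allocations vs. reuses, which is also what the new test below does:
```
import torch
from torch._inductor.utils import run_and_get_code

# Assumes mul2_kernel and f are defined as in the example above.
x = torch.randn(10, device="cuda", dtype=torch.float32)
compiled_out, (code,) = run_and_get_code(torch.compile(f), x)

# With this PR: 2 fresh allocations and 3 "# reuse" lines in the generated code.
print(code.count("empty_strided_cuda((10, ), (1, ), torch.float32)"))
print(code.count("# reuse"))
```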

Differential Revision: [D56379307](https://our.internmc.facebook.com/intern/diff/D56379307)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124425
Approved by: https://github.com/oulgen
Author: Colin Peppler
Date: 2024-04-19 22:58:20 -07:00
Committed by: PyTorch MergeBot
Parent: 0d90d4d613
Commit: cbf420b67a
5 changed files with 62 additions and 11 deletions

@@ -909,7 +909,7 @@ class InplacingTests(TestCase):
            return output

        inp = (T(10), T(10))
-        self.assertExpectedInline(count_numel(f, *inp), """80""")
+        self.assertExpectedInline(count_numel(f, *inp), """60""")

    @requires_cuda
    @skipIfRocm
@@ -939,7 +939,7 @@ class InplacingTests(TestCase):
            return output

        inp = (T(10), T(10))
-        self.assertExpectedInline(count_numel(f, *inp), """80""")
+        self.assertExpectedInline(count_numel(f, *inp), """60""")

    @requires_cuda
    @skipIfRocm

@@ -787,6 +787,42 @@ def forward(self, x_1, output_1):
        f(x_cloned)
        out.sum().backward()

+    @requires_cuda
+    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
+    def test_triton_kernel_inputs_buffer_reuse(self):
+        def _mul2(x):
+            y = torch.empty_like(x)
+            mul2_kernel[(10,)](
+                in_ptr0=x,
+                out_ptr=y,
+                n_elements=x.numel(),
+                BLOCK_SIZE=1,
+            )
+            return y
+
+        @torch.compile
+        def f(x):
+            for _ in range(4):
+                # The output of one kernel is the input to the next kernel, but
+                # at some point we should re-use buffers not allocate new ones.
+                x = _mul2(x)
+            return x + 1
+
+        x = torch.randn(10, device="cuda", dtype=torch.float32)
+        eager_out = f(x)
+        compiled_out, (code,) = run_and_get_code(torch.compile(f), x)
+        self.assertEqual(compiled_out, eager_out)
+
+        # Check that we're allocating the minimal # of buffers.
+        num_bufs_allocated = code.count(
+            "empty_strided_cuda((10, ), (1, ), torch.float32)"
+        )
+        self.assertEqual(num_bufs_allocated, 2)
+
+        # Check we're re-using buffers if not allocating.
+        num_bufs_reused = code.count("# reuse")
+        self.assertEqual(num_bufs_reused, 3)
+
    @requires_cuda
    def test_triton_kernel_matmul_tracking(self):
        @triton.jit

@@ -111,6 +111,7 @@ def generate_ttir(kernel, kwargs):
    """
    Uses Triton's internal code generation to create TTIR
    """
+    import sympy
    import triton
    from triton.compiler.compiler import ASTSource
    from triton.runtime.autotuner import Autotuner
@@ -132,14 +133,14 @@ def generate_ttir(kernel, kwargs):
        raise ValueError("Incorrect number of arguments passed to kernel")

    # Replace all SymExprs with a regular value for TTIR generation
-    # Replace all FakeTensor with real tensors
+    # Replace all FakeTensor/TensorBox with real tensors
    # These replacements are needed for triton's type, key and config functions
    ordered_args: Dict[str, Any] = {}
    for name in kernel.arg_names:
        a = kwargs[name]
-        if isinstance(a, (torch.SymInt, torch.SymFloat, torch.SymBool)):
+        if isinstance(a, (torch.SymInt, torch.SymFloat, torch.SymBool, sympy.Expr)):
            ordered_args[name] = 2
-        elif isinstance(a, FakeTensor):
+        elif isinstance(a, (FakeTensor, torch._inductor.ir.TensorBox)):
            with torch._C._DisableTorchDispatch():
                ordered_args[name] = torch.empty(2, dtype=a.dtype)
        else:

@@ -4650,6 +4650,9 @@ class UserDefinedTritonKernel(ExternKernel):
        return set()

    def get_mutation_names(self):
+        # NB: Inductor only allows a node to mutate 0 or 1 buffers.
+        # To get around that, we create MutationOutputs which marks their
+        # assigned input as mutable, thus, adhering to Inductor's constraint.
        return []

    def __init__(self, *, kernel_idx, grid, kernel_args):
@@ -4679,18 +4682,25 @@
        self.kernel_idx = kernel_idx
        self.grid = grid

-        kernel, _ = self.get_kernel_and_configs()
+        kernel, configs = self.get_kernel_and_configs()
        # If we are autotuning, not all arguments will be passed
        self.ordered_kwargs_for_cpp_kernel = [
            arg for arg in kernel.arg_names if arg in kernel_args
        ]

-        mark_node_as_mutating(
-            self, *[a for a in kernel_args.values() if isinstance(a, TensorBox)]
-        )
+        from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors
+
+        autotuned_kwargs = configs[0].kwargs if len(configs) > 0 else {}
+        self.mutable_args = [
+            kernel_args[key]
+            for key in identify_mutated_tensors(
+                kernel, {**kernel_args, **autotuned_kwargs}
+            )
+        ]
+        mark_node_as_mutating(self, *self.mutable_args)

    def get_inputs_that_alias_output(self):
-        return [i.get_name() for i in self.inputs]
+        return [i.get_name() for i in self.mutable_args]


def mark_node_as_mutating(cur_buffer, *mutated_ops: IRNode):
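
For intuition on the ir.py change above: `identify_mutated_tensors` analyzes the kernel's TTIR to determine which arguments the kernel actually writes to, and only those get marked as mutating / aliasing the output. A rough, illustrative sketch for the `mul2_kernel` example (real tensors stand in for the `TensorBox` args Inductor would pass):
```
import torch
from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors

x = torch.randn(10, device="cuda")
y = torch.empty_like(x)
kwargs = {"in_ptr0": x, "out_ptr": y, "n_elements": x.numel(), "BLOCK_SIZE": 1}

# mul2_kernel only stores through out_ptr, so only out_ptr should be reported
# as mutated; in_ptr0 no longer aliases the output and its buffer can be reused.
# (If the TTIR analysis fails, the helper conservatively falls back to treating
# every tensor argument as mutated.)
print(identify_mutated_tensors(mul2_kernel, kwargs))  # expected: ["out_ptr"]
```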

@@ -355,7 +355,11 @@ class BaseSchedulerNode:
                input_node: Optional[
                    BaseSchedulerNode
                ] = self.scheduler.name_to_node.get(read.name)
-                if input_node and V.graph.wrapper_code.can_reuse(input_node, self):
+                if (
+                    input_node
+                    and V.graph.wrapper_code.can_reuse(input_node, self)
+                    and not isinstance(input_node, NopKernelSchedulerNode)
+                ):
                    assert input_node.users is not None
                    remaining_uses = [
                        x