Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[inductor] for UserDefinedTritonKernels don't mark all inputs as mutating (#124425)
Take this example:

```
def _mul2(x):
    y = torch.empty_like(x)
    mul2_kernel[(10,)](
        in_ptr0=x,
        out_ptr=y,
        n_elements=x.numel(),
        BLOCK_SIZE=1,
    )
    return y


def f(x):
    for _ in range(4):
        x = _mul2(x)
    return x + 1
```

Currently, the codegen looks like this. Notice how we allocate 5 buffers of the same size:

```
# Source Nodes: [triton_kernel_wrapper_mutation], Original ATen: []
buf0 = empty_strided_cuda((10, ), (1, ), torch.float32)
mul2_kernel_0.run(in_ptr0=arg0_1, out_ptr=reinterpret_tensor(buf0, (10, ), (1, ), 0) ...)
# Source Nodes: [triton_kernel_wrapper_mutation_1], Original ATen: []
buf4 = empty_strided_cuda((10, ), (1, ), torch.float32)
mul2_kernel_0.run(in_ptr0=reinterpret_tensor(buf0, (10, ), (1, ), 0), out_ptr=reinterpret_tensor(buf4, (10, ), (1, ), 0) ...)
# Source Nodes: [triton_kernel_wrapper_mutation_2], Original ATen: []
buf8 = empty_strided_cuda((10, ), (1, ), torch.float32)
mul2_kernel_0.run(in_ptr0=reinterpret_tensor(buf4, (10, ), (1, ), 0), out_ptr=reinterpret_tensor(buf8, (10, ), (1, ), 0) ...)
# Source Nodes: [triton_kernel_wrapper_mutation_3], Original ATen: []
buf12 = empty_strided_cuda((10, ), (1, ), torch.float32)
mul2_kernel_0.run(in_ptr0=reinterpret_tensor(buf8, (10, ), (1, ), 0), out_ptr=reinterpret_tensor(buf12, (10, ), (1, ), 0) ...)
# Source Nodes: [add], Original ATen: [aten.add]
buf16 = empty_strided_cuda((10, ), (1, ), torch.float32)
triton_poi_fused_add_0.run(buf12, buf16, 10, grid=grid(10), stream=stream0)
return (buf16, )
```

With this PR, the codegen looks like this instead. Notice how we only allocate 2 buffers this time; the other 3 buffers are reused:

```
# Source Nodes: [triton_kernel_wrapper_mutation], Original ATen: []
buf0 = empty_strided_cuda((10, ), (1, ), torch.float32)
mul2_kernel_0.run(in_ptr0=arg0_1, out_ptr=reinterpret_tensor(buf0, (10, ), (1, ), 0), ...)
del arg0_1
# Source Nodes: [triton_kernel_wrapper_mutation_1], Original ATen: []
buf2 = empty_strided_cuda((10, ), (1, ), torch.float32)
mul2_kernel_0.run(in_ptr0=reinterpret_tensor(buf0, (10, ), (1, ), 0), out_ptr=reinterpret_tensor(buf2, (10, ), (1, ), 0) ...)
# Source Nodes: [triton_kernel_wrapper_mutation_2], Original ATen: []
buf4 = buf0; del buf0  # reuse
mul2_kernel_0.run(in_ptr0=reinterpret_tensor(buf2, (10, ), (1, ), 0), out_ptr=reinterpret_tensor(buf4, (10, ), (1, ), 0) ...)
# Source Nodes: [triton_kernel_wrapper_mutation_3], Original ATen: []
buf6 = buf2; del buf2  # reuse
mul2_kernel_0.run(in_ptr0=reinterpret_tensor(buf4, (10, ), (1, ), 0), out_ptr=reinterpret_tensor(buf6, (10, ), (1, ), 0) ...)
del buf4
# Source Nodes: [add], Original ATen: [aten.add]
buf8 = buf6; del buf6  # reuse
triton_poi_fused_add_0.run(buf8, 10, grid=grid(10), stream=stream0)
return (buf8, )
```

Differential Revision: [D56379307](https://our.internmc.facebook.com/intern/diff/D56379307)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124425
Approved by: https://github.com/oulgen
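The example above calls a `mul2_kernel` that is not shown in the description. As a point of reference only, a minimal sketch of such a kernel (the actual kernel defined in the test suite may differ) could look like:

```
import triton
import triton.language as tl


@triton.jit
def mul2_kernel(in_ptr0, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-wide slice of the tensor.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    # Only out_ptr is written; in_ptr0 is read-only, which is exactly the
    # property the buffer-reuse change in this PR relies on.
    tl.store(out_ptr + offsets, 2 * x, mask=mask)
```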
Committed by: PyTorch MergeBot
Parent: 0d90d4d613
Commit: cbf420b67a
```
@@ -909,7 +909,7 @@ class InplacingTests(TestCase):
             return output

         inp = (T(10), T(10))
-        self.assertExpectedInline(count_numel(f, *inp), """80""")
+        self.assertExpectedInline(count_numel(f, *inp), """60""")

     @requires_cuda
     @skipIfRocm
```
```
@@ -939,7 +939,7 @@ class InplacingTests(TestCase):
             return output

         inp = (T(10), T(10))
-        self.assertExpectedInline(count_numel(f, *inp), """80""")
+        self.assertExpectedInline(count_numel(f, *inp), """60""")

     @requires_cuda
     @skipIfRocm
```
```
@@ -787,6 +787,42 @@ def forward(self, x_1, output_1):
         f(x_cloned)
         out.sum().backward()

+    @requires_cuda
+    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
+    def test_triton_kernel_inputs_buffer_reuse(self):
+        def _mul2(x):
+            y = torch.empty_like(x)
+            mul2_kernel[(10,)](
+                in_ptr0=x,
+                out_ptr=y,
+                n_elements=x.numel(),
+                BLOCK_SIZE=1,
+            )
+            return y
+
+        @torch.compile
+        def f(x):
+            for _ in range(4):
+                # The output of one kernel is the input to the next kernel, but
+                # at some point we should re-use buffers not allocate new ones.
+                x = _mul2(x)
+            return x + 1
+
+        x = torch.randn(10, device="cuda", dtype=torch.float32)
+        eager_out = f(x)
+        compiled_out, (code,) = run_and_get_code(torch.compile(f), x)
+        self.assertEqual(compiled_out, eager_out)
+
+        # Check that we're allocating the minimal # of buffers.
+        num_bufs_allocated = code.count(
+            "empty_strided_cuda((10, ), (1, ), torch.float32)"
+        )
+        self.assertEqual(num_bufs_allocated, 2)
+
+        # Check we're re-using buffers if not allocating.
+        num_bufs_reused = code.count("# reuse")
+        self.assertEqual(num_bufs_reused, 3)
+
     @requires_cuda
     def test_triton_kernel_matmul_tracking(self):
         @triton.jit
```
```
@@ -111,6 +111,7 @@ def generate_ttir(kernel, kwargs):
     """
     Uses Triton's internal code generation to create TTIR
     """
+    import sympy
     import triton
     from triton.compiler.compiler import ASTSource
     from triton.runtime.autotuner import Autotuner
```
```
@@ -132,14 +133,14 @@ def generate_ttir(kernel, kwargs):
         raise ValueError("Incorrect number of arguments passed to kernel")

     # Replace all SymExprs with a regular value for TTIR generation
-    # Replace all FakeTensor with real tensors
+    # Replace all FakeTensor/TensorBox with real tensors
     # These replacements are needed for triton's type, key and config functions
     ordered_args: Dict[str, Any] = {}
     for name in kernel.arg_names:
         a = kwargs[name]
-        if isinstance(a, (torch.SymInt, torch.SymFloat, torch.SymBool)):
+        if isinstance(a, (torch.SymInt, torch.SymFloat, torch.SymBool, sympy.Expr)):
             ordered_args[name] = 2
-        elif isinstance(a, FakeTensor):
+        elif isinstance(a, (FakeTensor, torch._inductor.ir.TensorBox)):
             with torch._C._DisableTorchDispatch():
                 ordered_args[name] = torch.empty(2, dtype=a.dtype)
         else:
```
```
@@ -4650,6 +4650,9 @@ class UserDefinedTritonKernel(ExternKernel):
         return set()

     def get_mutation_names(self):
+        # NB: Inductor only allows a node to mutate 0 or 1 buffers.
+        # To get around that, we create MutationOutputs which marks their
+        # assigned input as mutable, thus, adhering to Inductor's constraint.
         return []

     def __init__(self, *, kernel_idx, grid, kernel_args):
```
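The NB comment above describes Inductor's workaround for its one-mutated-buffer-per-node rule. A toy illustration of that pattern, deliberately written without Inductor's real IR classes (names here are hypothetical), might look like:

```
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyNode:
    name: str
    mutated_buffer: Optional[str] = None  # Inductor-style rule: at most one per node


def mark_mutations(kernel_node: ToyNode, mutated_buffers: List[str]) -> List[ToyNode]:
    # Instead of letting the kernel node claim several mutated buffers, emit one
    # MutationOutput-like node per mutated input, each claiming exactly one.
    return [
        ToyNode(name=f"{kernel_node.name}_mutation_{i}", mutated_buffer=buf)
        for i, buf in enumerate(mutated_buffers)
    ]


outputs = mark_mutations(ToyNode("user_triton_kernel"), ["buf0", "buf2"])
assert [o.mutated_buffer for o in outputs] == ["buf0", "buf2"]
```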
```
@@ -4679,18 +4682,25 @@ class UserDefinedTritonKernel(ExternKernel):
         self.kernel_idx = kernel_idx
         self.grid = grid

-        kernel, _ = self.get_kernel_and_configs()
+        kernel, configs = self.get_kernel_and_configs()
         # If we are autotuning, not all arguments will be passed
         self.ordered_kwargs_for_cpp_kernel = [
             arg for arg in kernel.arg_names if arg in kernel_args
         ]

-        mark_node_as_mutating(
-            self, *[a for a in kernel_args.values() if isinstance(a, TensorBox)]
-        )
+        from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors
+
+        autotuned_kwargs = configs[0].kwargs if len(configs) > 0 else {}
+        self.mutable_args = [
+            kernel_args[key]
+            for key in identify_mutated_tensors(
+                kernel, {**kernel_args, **autotuned_kwargs}
+            )
+        ]
+        mark_node_as_mutating(self, *self.mutable_args)

     def get_inputs_that_alias_output(self):
-        return [i.get_name() for i in self.inputs]
+        return [i.get_name() for i in self.mutable_args]


 def mark_node_as_mutating(cur_buffer, *mutated_ops: IRNode):
```
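The constructor above now delegates mutation detection to `identify_mutated_tensors` instead of treating every tensor argument as mutated. A hedged usage sketch (assuming a CUDA build with Triton available and the `mul2_kernel` sketched near the top of this page; the exact return value is an expectation, not verified output) shows the behavior the buffer-reuse change relies on:

```
import torch
from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors

# mul2_kernel: the Triton kernel sketched earlier on this page (assumed in scope).
x = torch.randn(10, device="cuda", dtype=torch.float32)
y = torch.empty_like(x)

# Expected under these assumptions: only out_ptr is reported as mutated, so the
# input buffer no longer aliases the kernel's output and becomes eligible for reuse.
mutated = identify_mutated_tensors(
    mul2_kernel,
    {"in_ptr0": x, "out_ptr": y, "n_elements": x.numel(), "BLOCK_SIZE": 1},
)
print(mutated)  # expected: ["out_ptr"]
```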
```
@@ -355,7 +355,11 @@ class BaseSchedulerNode:
             input_node: Optional[
                 BaseSchedulerNode
             ] = self.scheduler.name_to_node.get(read.name)
-            if input_node and V.graph.wrapper_code.can_reuse(input_node, self):
+            if (
+                input_node
+                and V.graph.wrapper_code.can_reuse(input_node, self)
+                and not isinstance(input_node, NopKernelSchedulerNode)
+            ):
                 assert input_node.users is not None
                 remaining_uses = [
                     x
```