[inductor] Use multiple outputs for flex-attention (#130833)

Resubmit of #129344

This fixes the DCE issue for the attention output: the template buffer now reports its mutated inputs as separate MutationOutput nodes via get_outputs() instead of marking itself as mutating, so an unused attention output can be dead-code eliminated on its own.
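
For intuition, here is a minimal sketch (simplified stand-in classes, not the actual inductor IR) of why per-output bookkeeping unblocks DCE: the mutation is modelled as its own always-live node, so the attention output can be dropped when nothing reads it.

from typing import List


class Buffer:
    def __init__(self, name: str, users: int = 0):
        self.name = name
        self.users = users  # number of downstream reads


class MutationOutput(Buffer):
    # Stands for "this kernel mutated an existing buffer"; the mutation is
    # observable, so this output is always treated as live.
    pass


class TemplateBuffer(Buffer):
    def __init__(self, name: str, mutated_inputs: List[Buffer]):
        super().__init__(name)
        # One node, several outputs: the buffer itself plus one
        # MutationOutput per mutated input (mirroring the change below).
        self.outputs: List[Buffer] = [self]
        self.outputs += [
            MutationOutput(f"mutation_of_{b.name}") for b in mutated_inputs
        ]

    def get_outputs(self) -> List[Buffer]:
        return self.outputs


def live_outputs(node: TemplateBuffer) -> List[str]:
    # Per-output DCE: keep mutations unconditionally, anything else only
    # if it has readers.
    return [
        o.name
        for o in node.get_outputs()
        if isinstance(o, MutationOutput) or o.users > 0
    ]


# The attention output has no readers, but the logsumexp mutation survives:
attn = TemplateBuffer("flex_attention_out", [Buffer("logsumexp")])
print(live_outputs(attn))  # ['mutation_of_logsumexp']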

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130833
Approved by: https://github.com/lezcano
ghstack dependencies: #130832
Author: Peter Bell
Date: 2024-07-23 13:10:26 +01:00
Committed by: PyTorch MergeBot
Parent: 95c248751b
Commit: 6415c45da5

2 changed files with 10 additions and 6 deletions


@@ -1020,7 +1020,6 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         q, k, v = (torch.randn(1, 8, 1024, 64, device="cuda") for _ in range(3))
         metrics.reset()
         _, code = run_and_get_code(f, q, k, v)
-        # TODO: attention output is not being DCE'd
         fc = FileCheck()
         fc.check("triton_tem_fused")  # template call
         fc.check_not("poi_fused_cos")  # No cos pointwise operation
@@ -1028,10 +1027,8 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         accessed_bytes = 1 * 8 * 1024 * 64 * torch.float32.itemsize
         num_accesses = 4  # q, k, v reads, one output.
         # TODO: Get rid of this fudge factor
-        # We need this fudge factor for now, since
-        # 1. For some reason we materialize the output of the attention unnecessarily (it's related to the mutation somehow)
-        # 2. We also write the extraneous logsumexp
-        num_accesses += 2
+        # We need this fudge factor for now as we write the extraneous logsumexp
+        num_accesses += 1
         self.assertLess(metrics.num_bytes_accessed, accessed_bytes * num_accesses)

     @supported_platform
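
For reference, the budget this assertion enforces works out as follows (plain arithmetic, assuming float32 and the (1, 8, 1024, 64) shapes from the test):

itemsize = 4  # torch.float32.itemsize
tensor_bytes = 1 * 8 * 1024 * 64 * itemsize  # 2 MiB per (1, 8, 1024, 64) tensor
num_accesses = 4 + 1  # q/k/v reads and one output write, plus one extra
                      # full-tensor access as slack for the logsumexp write
assert tensor_bytes * num_accesses == 10 * 2**20  # 10 MiB upper bound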


@@ -3849,6 +3849,7 @@ class TritonTemplateBuffer(TemplateBuffer):
         super().__init__(layout, inputs, make_kernel_render)
         self.debug_extra = debug_extra
         self.mutated_inputs = mutated_inputs
+        self.outputs: List[Buffer] = [self]
         if mutated_inputs is not None:
             # Ensure that the mutated inputs are only allowed for certain nodes
             allowed_set = {
@@ -3859,7 +3860,13 @@ class TritonTemplateBuffer(TemplateBuffer):
             assert (
                 current_node in allowed_set
             ), f"Mutated inputs are only allowed for {allowed_set} but got {current_node}"
-            mark_node_as_mutating(self, *mutated_inputs)
+            device = self.inputs[0].get_device()
+            self.outputs += [
+                MutationOutput(NoneLayout(device), buf, self) for buf in mutated_inputs
+            ]
+
+    def get_outputs(self) -> List[Buffer]:
+        return self.outputs

     def __str__(self):
         out = f"TritonTemplateBuffer(layout={self.layout}, {self.debug_extra})"