Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[inductor] Use multiple outputs for flex-attention (#130833)
Resubmit of #129344. This fixes the DCE issue for the attention output.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130833
Approved by: https://github.com/lezcano
ghstack dependencies: #130832
committed by PyTorch MergeBot
parent 95c248751b
commit 6415c45da5
@@ -1020,7 +1020,6 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         q, k, v = (torch.randn(1, 8, 1024, 64, device="cuda") for _ in range(3))
         metrics.reset()
         _, code = run_and_get_code(f, q, k, v)
-        # TODO: attention output is not being DCE'd
         fc = FileCheck()
         fc.check("triton_tem_fused")  # template call
         fc.check_not("poi_fused_cos")  # No cos pointwise operation
@@ -1028,10 +1027,8 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         accessed_bytes = 1 * 8 * 1024 * 64 * torch.float32.itemsize
         num_accesses = 4  # q, k, v reads, one output.
         # TODO: Get rid of this fudge factor
-        # We need this fudge factor for now, since
-        # 1. For some reason we materialize the output of the attention unnecessarily (it's related to the mutation somehow)
-        # 2. We also write the extraneous logsumexp
-        num_accesses += 2
+        # We need this fudge factor for now as we write the extraneous logsumexp
+        num_accesses += 1
         self.assertLess(metrics.num_bytes_accessed, accessed_bytes * num_accesses)

     @supported_platform
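As a quick sanity check on the bound asserted in this hunk, the arithmetic works out as follows (a sketch, not part of the diff; the shapes and dtype are taken from the test above):

import torch

# Each of q, k, v and the attention output is a (1, 8, 1024, 64) fp32 tensor.
bytes_per_tensor = 1 * 8 * 1024 * 64 * torch.float32.itemsize  # 2 MiB
num_accesses = 4   # q, k, v reads plus the attention output write
num_accesses += 1  # fudge factor: the extraneous logsumexp write
# metrics.num_bytes_accessed is asserted to stay under this budget.
print(bytes_per_tensor * num_accesses)  # 10485760 bytes (10 MiB)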
@@ -3849,6 +3849,7 @@ class TritonTemplateBuffer(TemplateBuffer):
         super().__init__(layout, inputs, make_kernel_render)
         self.debug_extra = debug_extra
         self.mutated_inputs = mutated_inputs
+        self.outputs: List[Buffer] = [self]
         if mutated_inputs is not None:
             # Ensure that the mutated inputs are only allowed for certain nodes
             allowed_set = {
@@ -3859,7 +3860,13 @@ class TritonTemplateBuffer(TemplateBuffer):
             assert (
                 current_node in allowed_set
             ), f"Mutated inputs are only allowed for {allowed_set} but got {current_node}"
             mark_node_as_mutating(self, *mutated_inputs)
+            device = self.inputs[0].get_device()
+            self.outputs += [
+                MutationOutput(NoneLayout(device), buf, self) for buf in mutated_inputs
+            ]
+
+    def get_outputs(self) -> List[Buffer]:
+        return self.outputs

     def __str__(self):
         out = f"TritonTemplateBuffer(layout={self.layout}, {self.debug_extra})"
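For readers unfamiliar with the inductor IR, here is a simplified, standalone sketch of the multiple-outputs bookkeeping this hunk introduces. The Fake* classes are illustrative stand-ins, not the real Buffer/MutationOutput/NoneLayout types from torch/_inductor/ir.py:

from typing import List, Optional


class FakeBuffer:
    """Stand-in for inductor's Buffer (illustration only)."""

    def __init__(self, name: str):
        self.name = name


class FakeMutationOutput(FakeBuffer):
    """Stand-in for MutationOutput: records that `mutated` is written by `parent`."""

    def __init__(self, mutated: FakeBuffer, parent: "FakeTemplateBuffer"):
        super().__init__(f"{mutated.name}_mutation")
        self.mutated = mutated
        self.parent = parent


class FakeTemplateBuffer(FakeBuffer):
    """Sketch of the bookkeeping added to TritonTemplateBuffer in this commit."""

    def __init__(self, name: str, mutated_inputs: Optional[List[FakeBuffer]] = None):
        super().__init__(name)
        # The template buffer itself is always the first output; every mutated
        # input becomes its own output node rather than being folded into the
        # main buffer.
        self.outputs: List[FakeBuffer] = [self]
        if mutated_inputs is not None:
            self.outputs += [FakeMutationOutput(buf, self) for buf in mutated_inputs]

    def get_outputs(self) -> List[FakeBuffer]:
        return self.outputs


# Example: a flex-attention-style template that also writes logsumexp in place.
lse = FakeBuffer("logsumexp")
attn = FakeTemplateBuffer("attention_out", mutated_inputs=[lse])
print([out.name for out in attn.get_outputs()])  # ['attention_out', 'logsumexp_mutation']

Tracking each result as a separate output node lets the scheduler drop an unused result (here, the needlessly materialized attention output that the old test comment mentioned) without losing the mutation writes, which is the DCE fix the commit message describes.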