[Inductor] Inplacing with Donated Buffer (#140113)

Currently, Inductor does not in-place update a buffer if it is an input buffer, because we don't know whether the input will be used by other functions.

A donated buffer carries the additional guarantee that the input buffer will not be used by other functions, so we can in-place update a donated buffer when possible.

[Dashboard](https://hud.pytorch.org/benchmark/torchbench/inductor_dynamic?dashboard=torchinductor&startTime=Mon,%2011%20Nov%202024%2018:14:36%20GMT&stopTime=Mon,%2018%20Nov%202024%2018:14:36%20GMT&granularity=hour&mode=training&dtype=amp&deviceName=cuda%20(a100)&lBranch=bf/donated-buffer-inplace&lCommit=5df0769c00e6f9000caeb10fd5cbf0b165f69c2a&rBranch=main&rCommit=2b39a8db7741b816b03677a9c6fec1af05640dee)

![image](https://github.com/user-attachments/assets/f19d961f-7973-418e-9de8-5c2a97950478)
![image](https://github.com/user-attachments/assets/df3bd6a9-58b8-4e8a-8397-9e3b1de9adfe)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140113
Approved by: https://github.com/eellison
This commit is contained in:
Boyuan Feng
2024-11-26 17:19:50 +00:00
committed by PyTorch MergeBot
parent 3ef031909f
commit eecc8e362c
7 changed files with 180 additions and 17 deletions

View File

@ -125,10 +125,16 @@ class SchedulerBuffer:
hasattr(V.kernel, "args")
and self.get_name() in V.kernel.inplace_update_buffers
):
input_buffer: Union[ir.DonatedBuffer, ir.Buffer]
input_buffer_name = V.kernel.inplace_update_buffers[self.get_name()]
if input_buffer_name in self.scheduler.name_to_donated_buffer:
input_buffer = self.scheduler.name_to_donated_buffer[
input_buffer_name
].node
else:
input_buffer = self.scheduler.name_to_buf[input_buffer_name].node
V.graph.wrapper_code.codegen_inplace_reuse(
self.scheduler.name_to_buf[
V.kernel.inplace_update_buffers[self.get_name()]
].node,
input_buffer,
self.node,
)
else:
@ -163,6 +169,11 @@ class SchedulerBuffer:
return self.node.get_mutation_names()
@dataclasses.dataclass
class SchedulerDonatedBuffer(SchedulerBuffer):
    # A donated buffer is a graph input rather than the output of a scheduler
    # node, so it has no defining op. The base class declares `defining_op` as
    # required/non-optional, hence the Optional override and the ignore.
    defining_op: Optional[BaseSchedulerNode] = None  # type: ignore[assignment]
class BaseSchedulerNode:
group: Tuple[torch.device, Tuple[Tuple[sympy.Expr, ...], ...]]
read_writes: dependencies.ReadWrites
@ -442,9 +453,12 @@ class BaseSchedulerNode:
continue
for read in self.read_writes.reads:
input_buf: Optional[SchedulerBuffer] = self.scheduler.name_to_buf.get(
read.name
)
input_buf: Optional[Union[SchedulerBuffer, SchedulerDonatedBuffer]]
if read.name in self.scheduler.name_to_donated_buffer:
input_buf = self.scheduler.name_to_donated_buffer[read.name]
else:
input_buf = self.scheduler.name_to_buf.get(read.name)
if (
input_buf
and V.graph.wrapper_code.can_reuse(input_buf, self)
@ -470,7 +484,8 @@ class BaseSchedulerNode:
),
)
and not (
isinstance(
input_buf.defining_op
and isinstance(
input_buf.defining_op.node,
(ir.FallbackKernel, ir.MultiOutput),
)
@ -1801,6 +1816,9 @@ class Scheduler:
for node in self.nodes:
node.prune_deps()
self.name_to_donated_buffer: Dict[
str, SchedulerDonatedBuffer
] = self.get_donated_buffers()
self.name_to_node: Dict[str, BaseSchedulerNode] = {
n.get_name(): n for n in self.nodes
}
@ -1884,6 +1902,17 @@ class Scheduler:
}
)
def get_donated_buffers(self) -> Dict[str, SchedulerDonatedBuffer]:
    """Collect the graph inputs that are donated buffers.

    A donated buffer is a graph input guaranteed not to be used by other
    functions, so the scheduler may reuse its storage for in-place
    updates. Donated buffers are not produced by any scheduler node,
    hence ``defining_op=None``.

    Returns:
        Mapping from input name to its ``SchedulerDonatedBuffer`` wrapper.
    """
    name_to_donated_buf: Dict[str, SchedulerDonatedBuffer] = {}
    # Iterate items() once instead of re-indexing graph_inputs_original
    # three times per name.
    for name, inp in V.graph.graph_inputs_original.items():
        if isinstance(inp, ir.DonatedBuffer):
            name_to_donated_buf[name] = SchedulerDonatedBuffer(
                self,
                inp,
                defining_op=None,
            )
    return name_to_donated_buf
@property
def current_device(self) -> Optional[torch.device]:
    """Return the graph's current device (delegates to ``V.graph``); may be None."""
    return V.graph.current_device
@ -2160,6 +2189,9 @@ class Scheduler:
for buf in node.get_outputs():
buf.set_users(name_to_users[buf.get_name()].items)
for name in self.name_to_donated_buffer:
self.name_to_donated_buffer[name].set_users(name_to_users[name].items)
def dead_node_elimination(self) -> None:
"""
Remove any nodes without users