[Inductor] Inplacing with Donated Buffer (#140113)
Currently, Inductor does not in-place update a buffer if it is an input buffer, because we don't know whether that input will be used by other functions. A donated buffer carries the additional guarantee that the input will not be used by any other function, so we can in-place update donated buffers when possible.

[Dashboard](https://hud.pytorch.org/benchmark/torchbench/inductor_dynamic?dashboard=torchinductor&startTime=Mon,%2011%20Nov%202024%2018:14:36%20GMT&stopTime=Mon,%2018%20Nov%202024%2018:14:36%20GMT&granularity=hour&mode=training&dtype=amp&deviceName=cuda%20(a100)&lBranch=bf/donated-buffer-inplace&lCommit=5df0769c00e6f9000caeb10fd5cbf0b165f69c2a&rBranch=main&rCommit=2b39a8db7741b816b03677a9c6fec1af05640dee)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140113
Approved by: https://github.com/eellison
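As a rough illustration of where donated buffers come from (a minimal sketch, not part of this PR; the function and tensor names are illustrative): with torch.compile, the activations that AOTAutograd saves for the compiled backward are reachable only from that backward graph, so they can be marked as donated and Inductor is then free to reuse their storage in place.

```python
# Minimal sketch, assuming donated buffers arise from activations saved for the
# compiled backward. Nothing outside the compiled backward graph references the
# saved relu activation, so Inductor may reuse its storage for an output instead
# of allocating a fresh buffer.
import torch

def loss_fn(w, x):
    return torch.relu(x @ w).sum()  # the relu output is saved for backward

w = torch.randn(128, 128, requires_grad=True)
x = torch.randn(64, 128)

compiled_loss = torch.compile(loss_fn)
compiled_loss(w, x).backward()  # saved activations enter backward as donated inputs
```

Whether the reuse actually happens still depends on the scheduler checks touched in the diff below (remaining uses, layout, defining op, and so on).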
Committed by: PyTorch MergeBot
Parent: 3ef031909f
Commit: eecc8e362c
@@ -125,10 +125,16 @@ class SchedulerBuffer:
             hasattr(V.kernel, "args")
             and self.get_name() in V.kernel.inplace_update_buffers
         ):
+            input_buffer: Union[ir.DonatedBuffer, ir.Buffer]
+            input_buffer_name = V.kernel.inplace_update_buffers[self.get_name()]
+            if input_buffer_name in self.scheduler.name_to_donated_buffer:
+                input_buffer = self.scheduler.name_to_donated_buffer[
+                    input_buffer_name
+                ].node
+            else:
+                input_buffer = self.scheduler.name_to_buf[input_buffer_name].node
             V.graph.wrapper_code.codegen_inplace_reuse(
-                self.scheduler.name_to_buf[
-                    V.kernel.inplace_update_buffers[self.get_name()]
-                ].node,
+                input_buffer,
                 self.node,
             )
         else:
@@ -163,6 +169,11 @@ class SchedulerBuffer:
         return self.node.get_mutation_names()
 
 
+@dataclasses.dataclass
+class SchedulerDonatedBuffer(SchedulerBuffer):
+    defining_op: Optional[BaseSchedulerNode] = None  # type: ignore[assignment]
+
+
 class BaseSchedulerNode:
     group: Tuple[torch.device, Tuple[Tuple[sympy.Expr, ...], ...]]
     read_writes: dependencies.ReadWrites
@@ -442,9 +453,12 @@ class BaseSchedulerNode:
                 continue
 
             for read in self.read_writes.reads:
-                input_buf: Optional[SchedulerBuffer] = self.scheduler.name_to_buf.get(
-                    read.name
-                )
+                input_buf: Optional[Union[SchedulerBuffer, SchedulerDonatedBuffer]]
+                if read.name in self.scheduler.name_to_donated_buffer:
+                    input_buf = self.scheduler.name_to_donated_buffer[read.name]
+                else:
+                    input_buf = self.scheduler.name_to_buf.get(read.name)
+
                 if (
                     input_buf
                     and V.graph.wrapper_code.can_reuse(input_buf, self)
@@ -470,7 +484,8 @@ class BaseSchedulerNode:
                             ),
                         )
                         and not (
-                            isinstance(
+                            input_buf.defining_op
+                            and isinstance(
                                 input_buf.defining_op.node,
                                 (ir.FallbackKernel, ir.MultiOutput),
                             )
@@ -1801,6 +1816,9 @@ class Scheduler:
         for node in self.nodes:
             node.prune_deps()
 
+        self.name_to_donated_buffer: Dict[
+            str, SchedulerDonatedBuffer
+        ] = self.get_donated_buffers()
         self.name_to_node: Dict[str, BaseSchedulerNode] = {
             n.get_name(): n for n in self.nodes
         }
@@ -1884,6 +1902,17 @@ class Scheduler:
                 }
             )
 
+    def get_donated_buffers(self) -> Dict[str, SchedulerDonatedBuffer]:
+        name_to_donated_buf = {}
+        for name in V.graph.graph_inputs_original:
+            if isinstance(V.graph.graph_inputs_original[name], ir.DonatedBuffer):
+                name_to_donated_buf[name] = SchedulerDonatedBuffer(
+                    self,
+                    V.graph.graph_inputs_original[name],
+                    defining_op=None,
+                )
+        return name_to_donated_buf
+
     @property
     def current_device(self) -> Optional[torch.device]:
         return V.graph.current_device
@@ -2160,6 +2189,9 @@ class Scheduler:
             for buf in node.get_outputs():
                 buf.set_users(name_to_users[buf.get_name()].items)
 
+        for name in self.name_to_donated_buffer:
+            self.name_to_donated_buffer[name].set_users(name_to_users[name].items)
+
     def dead_node_elimination(self) -> None:
         """
         Remove any nodes without users