[Inductor] Inplacing with Donated Buffer (#140113)

Currently, inductor does not inplace update a buffer if it is an input buffer. Because we don't know if an input will be used by other functions.

Donated buffer provides additional information that an input buffer will not be used by other functions. So we can inplace update donated buffer when possible.

[Dashboard](https://hud.pytorch.org/benchmark/torchbench/inductor_dynamic?dashboard=torchinductor&startTime=Mon,%2011%20Nov%202024%2018:14:36%20GMT&stopTime=Mon,%2018%20Nov%202024%2018:14:36%20GMT&granularity=hour&mode=training&dtype=amp&deviceName=cuda%20(a100)&lBranch=bf/donated-buffer-inplace&lCommit=5df0769c00e6f9000caeb10fd5cbf0b165f69c2a&rBranch=main&rCommit=2b39a8db7741b816b03677a9c6fec1af05640dee)

![image](https://github.com/user-attachments/assets/f19d961f-7973-418e-9de8-5c2a97950478)
![image](https://github.com/user-attachments/assets/df3bd6a9-58b8-4e8a-8397-9e3b1de9adfe)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140113
Approved by: https://github.com/eellison
This commit is contained in:
Boyuan Feng
2024-11-26 17:19:50 +00:00
committed by PyTorch MergeBot
parent 3ef031909f
commit eecc8e362c
7 changed files with 180 additions and 17 deletions

View File

@ -5198,6 +5198,31 @@ class CommonTemplate:
if self.device != "cpu":
assertGeneratedKernelCountEqual(self, 1)
def test_matmul_layer_norm(self):
batch_size = 32
seq_length = 50
hidden_size = 256
inp = torch.randn(
batch_size,
seq_length,
hidden_size,
requires_grad=True,
device=self.device,
)
weight = torch.randn(
hidden_size, hidden_size, requires_grad=True, device=self.device
)
layer_norm = torch.nn.LayerNorm(hidden_size, device=self.device)
def foo(inp, weight):
matmul_output = inp @ weight
final_output = layer_norm(matmul_output)
return final_output
self.common(foo, (inp, weight), check_lowp=False)
def test_transpose_add(self):
def fn(a, b):
return a.t() + b
@ -12855,6 +12880,43 @@ if HAS_GPU and not TEST_WITH_ASAN:
self.assertTrue(len(re.findall(r"in_out_ptr\d+", code)) > 0)
self.assertEqual(fn_opt(*inps), fn(*inps))
def test_donated_buffer_inplace(self):
batch_size = 32
seq_length = 50
hidden_size = 256
inp = torch.randn(
batch_size,
seq_length,
hidden_size,
requires_grad=True,
device=self.device,
)
weight = torch.randn(
hidden_size, hidden_size, requires_grad=True, device=self.device
)
layer_norm = torch.nn.LayerNorm(hidden_size, device=self.device)
def fn(inp, weight):
matmul_output = inp @ weight
final_output = layer_norm(matmul_output)
return final_output
fn_opt = torch.compile(fn)
def wrapper(inp, weight):
return fn_opt(inp, weight).sum().backward()
_, code = run_and_get_code(wrapper, inp, weight)
if config.cpp_wrapper:
# when using cpp_wrapper, backward triton code is in code[2]
self.assertTrue("in_out_ptr" in code[2])
else:
# when not using cpp_wrapper, backward triton code is in code[1]
self.assertTrue("in_out_ptr" in code[1])
class RNNTest(TestCase):
device_type = GPU_TYPE

View File

@ -2120,7 +2120,11 @@ class PythonWrapperCodegen(CodeGen):
def codegen_allocation(self, buffer: ir.Buffer):
name = buffer.get_name()
if name in V.graph.removed_buffers or name in self.allocated:
if (
name in V.graph.removed_buffers
or name in self.allocated
or isinstance(buffer, ir.DonatedBuffer)
):
return
self.allocated.add(name)
if isinstance(
@ -2174,7 +2178,12 @@ class PythonWrapperCodegen(CodeGen):
name = input_buffer.get_name()
return not (
name in V.graph.removed_buffers
or name in V.graph.graph_inputs
or (
name in V.graph.graph_inputs
and not isinstance(
V.graph.graph_inputs_original[name], ir.DonatedBuffer
)
)
or name in V.graph.constants
or name in V.graph.torchbind_constants
or name in V.graph.never_reuse_buffers

View File

@ -832,6 +832,20 @@ class CUDAGraphNode:
if isinstance(t, torch.Tensor) and self._is_cuda_graph_recorded_tensor(t)
]
# (depth, offset) of live tensors which are alias of previous graph outputs
self.live_cudagraph_managed_path_refs: InputList[Optional[PathOutputIndex]] = [
(
self._is_alias_of_live_recorded_tensor(t)
if isinstance(t, torch.Tensor)
else None
)
for t in inputs
]
# when replay, preserve the liveness of an input if it AliasesPriorGraphOutput
# and also aliases an output of the current CUDAGraphNode
self.preserved_aliased_inputs: InputList[bool] = [False] * len(inputs)
self.static_input_idxs: List[int] = list(
set(wrapped_function.static_input_idxs) | set(self.cudagraph_managed_idxs)
)
@ -1038,11 +1052,11 @@ class CUDAGraphNode:
self.check_static_inputs_are_stable(new_inputs)
self._copy_inputs_and_remove_from_src(self.reconstructed_inputs, new_inputs)
new_inputs.clear()
self.run_graph()
outputs = self.reconstruct_outputs()
new_inputs.clear()
if config.triton.fast_path_cudagraph_asserts:
self.debug_check_invariants_after_invocation()
@ -1261,6 +1275,12 @@ class CUDAGraphNode:
path_ref = self._is_alias_of_live_recorded_tensor(o)
if path_ref is not None:
self._mark_prior_graph_output_as_aliased(path_ref)
for idx, inp_path_ref in enumerate(
self.live_cudagraph_managed_path_refs
):
if path_ref == inp_path_ref:
self.preserved_aliased_inputs[idx] = True
self.output_storage_alias.append(AliasesPriorGraphOutput(path_ref))
continue
@ -1667,7 +1687,8 @@ class CUDAGraphNode:
# this invocation. it is too late to check after we've replayed the graph,
# because we would have already written over their memory.
for idx in self.cudagraph_managed_idxs:
inputs[idx] = None # type: ignore[call-overload]
if not self.preserved_aliased_inputs[idx]:
inputs[idx] = None # type: ignore[call-overload]
torch._check(
self._check_liveness(

View File

@ -74,6 +74,7 @@ from .exc import (
)
from .ir import (
Constant,
DonatedBuffer,
FixedLayout,
get_device_type,
InputBuffer,
@ -103,6 +104,7 @@ from .utils import (
convert_shape_to_inductor,
gather_origins,
get_cloned_parameter_buffer_name,
get_donated_idxs,
get_sympy_Expr_dtype,
is_same_tensor,
maybe_get_suppress_shape_guards_ctx,
@ -486,6 +488,11 @@ class GraphLowering(torch.fx.Interpreter):
# state used by for Kernel.workspace
self.workspace_id = itertools.count()
# track the current placeholder index that we are processing
self.placeholder_idx = -1
self.bw_donated_idxs = get_donated_idxs()
def has_feature(
self,
device: Union[torch._inductor.ir.IRNode, device, None],
@ -963,6 +970,7 @@ class GraphLowering(torch.fx.Interpreter):
def placeholder(
self, target: str, args: Tuple[object], kwargs: Dict[str, object] # type: ignore[override]
) -> Union[Expr, TensorBox, None]:
self.placeholder_idx += 1
example = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
target = self.qualify_name(target)
if isinstance(example, SymTypes):
@ -993,13 +1001,27 @@ class GraphLowering(torch.fx.Interpreter):
sizes, strides = self.static_sizes_strides(example)
else:
sizes, strides = self.symbolic_sizes_strides(example) # type: ignore[assignment]
# TODO(jansel): handle input aliasing
tensor = TensorBox.create(
InputBuffer(
name=target,
layout=FixedLayout(example.device, example.dtype, sizes, strides),
if (
self.is_backward
and self.bw_donated_idxs
and self.placeholder_idx in self.bw_donated_idxs
):
tensor = TensorBox.create(
DonatedBuffer(
name=target,
layout=FixedLayout(example.device, example.dtype, sizes, strides),
)
)
)
else:
# TODO(jansel): handle input aliasing
tensor = TensorBox.create(
InputBuffer(
name=target,
layout=FixedLayout(example.device, example.dtype, sizes, strides),
)
)
self.graph_inputs[target] = tensor
self.graph_input_names.append(target)
self.graph_inputs_original[target] = tensor.data.data

View File

@ -3832,6 +3832,16 @@ class InputBuffer(Buffer):
return 1
class DonatedBuffer(InputBuffer):
"""
Represents a donated buffer which is a saved tensor that is not alias to any
fwd inputs, fwd user outputs, and bwd outputs. We generally cannot inplace
reuse the input tensor memory during backward since it might be used in another
function. However, donated buffer can be inplace reused during backward
to save memory.
"""
class ConstantBuffer(InputBuffer):
override_device: Optional[torch.device] = None

View File

@ -125,10 +125,16 @@ class SchedulerBuffer:
hasattr(V.kernel, "args")
and self.get_name() in V.kernel.inplace_update_buffers
):
input_buffer: Union[ir.DonatedBuffer, ir.Buffer]
input_buffer_name = V.kernel.inplace_update_buffers[self.get_name()]
if input_buffer_name in self.scheduler.name_to_donated_buffer:
input_buffer = self.scheduler.name_to_donated_buffer[
input_buffer_name
].node
else:
input_buffer = self.scheduler.name_to_buf[input_buffer_name].node
V.graph.wrapper_code.codegen_inplace_reuse(
self.scheduler.name_to_buf[
V.kernel.inplace_update_buffers[self.get_name()]
].node,
input_buffer,
self.node,
)
else:
@ -163,6 +169,11 @@ class SchedulerBuffer:
return self.node.get_mutation_names()
@dataclasses.dataclass
class SchedulerDonatedBuffer(SchedulerBuffer):
defining_op: Optional[BaseSchedulerNode] = None # type: ignore[assignment]
class BaseSchedulerNode:
group: Tuple[torch.device, Tuple[Tuple[sympy.Expr, ...], ...]]
read_writes: dependencies.ReadWrites
@ -442,9 +453,12 @@ class BaseSchedulerNode:
continue
for read in self.read_writes.reads:
input_buf: Optional[SchedulerBuffer] = self.scheduler.name_to_buf.get(
read.name
)
input_buf: Optional[Union[SchedulerBuffer, SchedulerDonatedBuffer]]
if read.name in self.scheduler.name_to_donated_buffer:
input_buf = self.scheduler.name_to_donated_buffer[read.name]
else:
input_buf = self.scheduler.name_to_buf.get(read.name)
if (
input_buf
and V.graph.wrapper_code.can_reuse(input_buf, self)
@ -470,7 +484,8 @@ class BaseSchedulerNode:
),
)
and not (
isinstance(
input_buf.defining_op
and isinstance(
input_buf.defining_op.node,
(ir.FallbackKernel, ir.MultiOutput),
)
@ -1801,6 +1816,9 @@ class Scheduler:
for node in self.nodes:
node.prune_deps()
self.name_to_donated_buffer: Dict[
str, SchedulerDonatedBuffer
] = self.get_donated_buffers()
self.name_to_node: Dict[str, BaseSchedulerNode] = {
n.get_name(): n for n in self.nodes
}
@ -1884,6 +1902,17 @@ class Scheduler:
}
)
def get_donated_buffers(self) -> Dict[str, SchedulerDonatedBuffer]:
name_to_donated_buf = {}
for name in V.graph.graph_inputs_original:
if isinstance(V.graph.graph_inputs_original[name], ir.DonatedBuffer):
name_to_donated_buf[name] = SchedulerDonatedBuffer(
self,
V.graph.graph_inputs_original[name],
defining_op=None,
)
return name_to_donated_buf
@property
def current_device(self) -> Optional[torch.device]:
return V.graph.current_device
@ -2160,6 +2189,9 @@ class Scheduler:
for buf in node.get_outputs():
buf.set_users(name_to_users[buf.get_name()].items)
for name in self.name_to_donated_buffer:
self.name_to_donated_buffer[name].set_users(name_to_users[name].items)
def dead_node_elimination(self) -> None:
"""
Remove any nodes without users

View File

@ -2200,3 +2200,10 @@ def ir_dataclass(cls=None, /, *, frozen: bool = True):
if cls is None:
return wrap
return wrap(cls)
def get_donated_idxs() -> Optional[List[int]]:
tracing_context = torch._guards.TracingContext.try_get()
if tracing_context is not None and tracing_context.fw_metadata:
return tracing_context.fw_metadata.bw_donated_idxs
return None