From eecc8e362c2eb192cbe13322af941d09ca647a6b Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 26 Nov 2024 17:19:50 +0000 Subject: [PATCH] [Inductor] Inplacing with Donated Buffer (#140113) Currently, inductor does not inplace update a buffer if it is an input buffer. Because we don't know if an input will be used by other functions. Donated buffer provides additional information that an input buffer will not be used by other functions. So we can inplace update donated buffer when possible. [Dashboard](https://hud.pytorch.org/benchmark/torchbench/inductor_dynamic?dashboard=torchinductor&startTime=Mon,%2011%20Nov%202024%2018:14:36%20GMT&stopTime=Mon,%2018%20Nov%202024%2018:14:36%20GMT&granularity=hour&mode=training&dtype=amp&deviceName=cuda%20(a100)&lBranch=bf/donated-buffer-inplace&lCommit=5df0769c00e6f9000caeb10fd5cbf0b165f69c2a&rBranch=main&rCommit=2b39a8db7741b816b03677a9c6fec1af05640dee) ![image](https://github.com/user-attachments/assets/f19d961f-7973-418e-9de8-5c2a97950478) ![image](https://github.com/user-attachments/assets/df3bd6a9-58b8-4e8a-8397-9e3b1de9adfe) Pull Request resolved: https://github.com/pytorch/pytorch/pull/140113 Approved by: https://github.com/eellison --- test/inductor/test_torchinductor.py | 62 +++++++++++++++++++++++++++++ torch/_inductor/codegen/wrapper.py | 13 +++++- torch/_inductor/cudagraph_trees.py | 25 +++++++++++- torch/_inductor/graph.py | 34 +++++++++++++--- torch/_inductor/ir.py | 10 +++++ torch/_inductor/scheduler.py | 46 +++++++++++++++++---- torch/_inductor/utils.py | 7 ++++ 7 files changed, 180 insertions(+), 17 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index c2a7731fee0c..24dc8de068e2 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -5198,6 +5198,31 @@ class CommonTemplate: if self.device != "cpu": assertGeneratedKernelCountEqual(self, 1) + def test_matmul_layer_norm(self): + batch_size = 32 + seq_length = 50 + hidden_size = 256 + + inp = torch.randn( + batch_size, + seq_length, + hidden_size, + requires_grad=True, + device=self.device, + ) + weight = torch.randn( + hidden_size, hidden_size, requires_grad=True, device=self.device + ) + + layer_norm = torch.nn.LayerNorm(hidden_size, device=self.device) + + def foo(inp, weight): + matmul_output = inp @ weight + final_output = layer_norm(matmul_output) + return final_output + + self.common(foo, (inp, weight), check_lowp=False) + def test_transpose_add(self): def fn(a, b): return a.t() + b @@ -12855,6 +12880,43 @@ if HAS_GPU and not TEST_WITH_ASAN: self.assertTrue(len(re.findall(r"in_out_ptr\d+", code)) > 0) self.assertEqual(fn_opt(*inps), fn(*inps)) + def test_donated_buffer_inplace(self): + batch_size = 32 + seq_length = 50 + hidden_size = 256 + + inp = torch.randn( + batch_size, + seq_length, + hidden_size, + requires_grad=True, + device=self.device, + ) + weight = torch.randn( + hidden_size, hidden_size, requires_grad=True, device=self.device + ) + + layer_norm = torch.nn.LayerNorm(hidden_size, device=self.device) + + def fn(inp, weight): + matmul_output = inp @ weight + final_output = layer_norm(matmul_output) + return final_output + + fn_opt = torch.compile(fn) + + def wrapper(inp, weight): + return fn_opt(inp, weight).sum().backward() + + _, code = run_and_get_code(wrapper, inp, weight) + + if config.cpp_wrapper: + # when using cpp_wrapper, backward triton code is in code[2] + self.assertTrue("in_out_ptr" in code[2]) + else: + # when not using cpp_wrapper, backward triton code is in code[1] + self.assertTrue("in_out_ptr" in code[1]) + class RNNTest(TestCase): device_type = GPU_TYPE diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 16d093852ade..15fe504446b0 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -2120,7 +2120,11 @@ class PythonWrapperCodegen(CodeGen): def codegen_allocation(self, buffer: ir.Buffer): name = buffer.get_name() - if name in V.graph.removed_buffers or name in self.allocated: + if ( + name in V.graph.removed_buffers + or name in self.allocated + or isinstance(buffer, ir.DonatedBuffer) + ): return self.allocated.add(name) if isinstance( @@ -2174,7 +2178,12 @@ class PythonWrapperCodegen(CodeGen): name = input_buffer.get_name() return not ( name in V.graph.removed_buffers - or name in V.graph.graph_inputs + or ( + name in V.graph.graph_inputs + and not isinstance( + V.graph.graph_inputs_original[name], ir.DonatedBuffer + ) + ) or name in V.graph.constants or name in V.graph.torchbind_constants or name in V.graph.never_reuse_buffers diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py index fed24b5d69e1..26d11e767f90 100644 --- a/torch/_inductor/cudagraph_trees.py +++ b/torch/_inductor/cudagraph_trees.py @@ -832,6 +832,20 @@ class CUDAGraphNode: if isinstance(t, torch.Tensor) and self._is_cuda_graph_recorded_tensor(t) ] + # (depth, offset) of live tensors which are alias of previous graph outputs + self.live_cudagraph_managed_path_refs: InputList[Optional[PathOutputIndex]] = [ + ( + self._is_alias_of_live_recorded_tensor(t) + if isinstance(t, torch.Tensor) + else None + ) + for t in inputs + ] + + # when replay, preserve the liveness of an input if it AliasesPriorGraphOutput + # and also aliases an output of the current CUDAGraphNode + self.preserved_aliased_inputs: InputList[bool] = [False] * len(inputs) + self.static_input_idxs: List[int] = list( set(wrapped_function.static_input_idxs) | set(self.cudagraph_managed_idxs) ) @@ -1038,11 +1052,11 @@ class CUDAGraphNode: self.check_static_inputs_are_stable(new_inputs) self._copy_inputs_and_remove_from_src(self.reconstructed_inputs, new_inputs) - new_inputs.clear() self.run_graph() outputs = self.reconstruct_outputs() + new_inputs.clear() if config.triton.fast_path_cudagraph_asserts: self.debug_check_invariants_after_invocation() @@ -1261,6 +1275,12 @@ class CUDAGraphNode: path_ref = self._is_alias_of_live_recorded_tensor(o) if path_ref is not None: self._mark_prior_graph_output_as_aliased(path_ref) + + for idx, inp_path_ref in enumerate( + self.live_cudagraph_managed_path_refs + ): + if path_ref == inp_path_ref: + self.preserved_aliased_inputs[idx] = True self.output_storage_alias.append(AliasesPriorGraphOutput(path_ref)) continue @@ -1667,7 +1687,8 @@ class CUDAGraphNode: # this invocation. it is too late to check after we've replayed the graph, # because we would have already written over their memory. for idx in self.cudagraph_managed_idxs: - inputs[idx] = None # type: ignore[call-overload] + if not self.preserved_aliased_inputs[idx]: + inputs[idx] = None # type: ignore[call-overload] torch._check( self._check_liveness( diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index d1f3d34eda45..b3d4e2c0c331 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -74,6 +74,7 @@ from .exc import ( ) from .ir import ( Constant, + DonatedBuffer, FixedLayout, get_device_type, InputBuffer, @@ -103,6 +104,7 @@ from .utils import ( convert_shape_to_inductor, gather_origins, get_cloned_parameter_buffer_name, + get_donated_idxs, get_sympy_Expr_dtype, is_same_tensor, maybe_get_suppress_shape_guards_ctx, @@ -486,6 +488,11 @@ class GraphLowering(torch.fx.Interpreter): # state used by for Kernel.workspace self.workspace_id = itertools.count() + # track the current placeholder index that we are processing + self.placeholder_idx = -1 + + self.bw_donated_idxs = get_donated_idxs() + def has_feature( self, device: Union[torch._inductor.ir.IRNode, device, None], @@ -963,6 +970,7 @@ class GraphLowering(torch.fx.Interpreter): def placeholder( self, target: str, args: Tuple[object], kwargs: Dict[str, object] # type: ignore[override] ) -> Union[Expr, TensorBox, None]: + self.placeholder_idx += 1 example = super().placeholder(target, args, kwargs) # type: ignore[arg-type] target = self.qualify_name(target) if isinstance(example, SymTypes): @@ -993,13 +1001,27 @@ class GraphLowering(torch.fx.Interpreter): sizes, strides = self.static_sizes_strides(example) else: sizes, strides = self.symbolic_sizes_strides(example) # type: ignore[assignment] - # TODO(jansel): handle input aliasing - tensor = TensorBox.create( - InputBuffer( - name=target, - layout=FixedLayout(example.device, example.dtype, sizes, strides), + + if ( + self.is_backward + and self.bw_donated_idxs + and self.placeholder_idx in self.bw_donated_idxs + ): + tensor = TensorBox.create( + DonatedBuffer( + name=target, + layout=FixedLayout(example.device, example.dtype, sizes, strides), + ) ) - ) + else: + # TODO(jansel): handle input aliasing + tensor = TensorBox.create( + InputBuffer( + name=target, + layout=FixedLayout(example.device, example.dtype, sizes, strides), + ) + ) + self.graph_inputs[target] = tensor self.graph_input_names.append(target) self.graph_inputs_original[target] = tensor.data.data diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 6379d5c99d34..e279f6534bb4 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3832,6 +3832,16 @@ class InputBuffer(Buffer): return 1 +class DonatedBuffer(InputBuffer): + """ + Represents a donated buffer which is a saved tensor that is not alias to any + fwd inputs, fwd user outputs, and bwd outputs. We generally cannot inplace + reuse the input tensor memory during backward since it might be used in another + function. However, donated buffer can be inplace reused during backward + to save memory. + """ + + class ConstantBuffer(InputBuffer): override_device: Optional[torch.device] = None diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 90b561379a5d..c69ed6574557 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -125,10 +125,16 @@ class SchedulerBuffer: hasattr(V.kernel, "args") and self.get_name() in V.kernel.inplace_update_buffers ): + input_buffer: Union[ir.DonatedBuffer, ir.Buffer] + input_buffer_name = V.kernel.inplace_update_buffers[self.get_name()] + if input_buffer_name in self.scheduler.name_to_donated_buffer: + input_buffer = self.scheduler.name_to_donated_buffer[ + input_buffer_name + ].node + else: + input_buffer = self.scheduler.name_to_buf[input_buffer_name].node V.graph.wrapper_code.codegen_inplace_reuse( - self.scheduler.name_to_buf[ - V.kernel.inplace_update_buffers[self.get_name()] - ].node, + input_buffer, self.node, ) else: @@ -163,6 +169,11 @@ class SchedulerBuffer: return self.node.get_mutation_names() +@dataclasses.dataclass +class SchedulerDonatedBuffer(SchedulerBuffer): + defining_op: Optional[BaseSchedulerNode] = None # type: ignore[assignment] + + class BaseSchedulerNode: group: Tuple[torch.device, Tuple[Tuple[sympy.Expr, ...], ...]] read_writes: dependencies.ReadWrites @@ -442,9 +453,12 @@ class BaseSchedulerNode: continue for read in self.read_writes.reads: - input_buf: Optional[SchedulerBuffer] = self.scheduler.name_to_buf.get( - read.name - ) + input_buf: Optional[Union[SchedulerBuffer, SchedulerDonatedBuffer]] + if read.name in self.scheduler.name_to_donated_buffer: + input_buf = self.scheduler.name_to_donated_buffer[read.name] + else: + input_buf = self.scheduler.name_to_buf.get(read.name) + if ( input_buf and V.graph.wrapper_code.can_reuse(input_buf, self) @@ -470,7 +484,8 @@ class BaseSchedulerNode: ), ) and not ( - isinstance( + input_buf.defining_op + and isinstance( input_buf.defining_op.node, (ir.FallbackKernel, ir.MultiOutput), ) @@ -1801,6 +1816,9 @@ class Scheduler: for node in self.nodes: node.prune_deps() + self.name_to_donated_buffer: Dict[ + str, SchedulerDonatedBuffer + ] = self.get_donated_buffers() self.name_to_node: Dict[str, BaseSchedulerNode] = { n.get_name(): n for n in self.nodes } @@ -1884,6 +1902,17 @@ class Scheduler: } ) + def get_donated_buffers(self) -> Dict[str, SchedulerDonatedBuffer]: + name_to_donated_buf = {} + for name in V.graph.graph_inputs_original: + if isinstance(V.graph.graph_inputs_original[name], ir.DonatedBuffer): + name_to_donated_buf[name] = SchedulerDonatedBuffer( + self, + V.graph.graph_inputs_original[name], + defining_op=None, + ) + return name_to_donated_buf + @property def current_device(self) -> Optional[torch.device]: return V.graph.current_device @@ -2160,6 +2189,9 @@ class Scheduler: for buf in node.get_outputs(): buf.set_users(name_to_users[buf.get_name()].items) + for name in self.name_to_donated_buffer: + self.name_to_donated_buffer[name].set_users(name_to_users[name].items) + def dead_node_elimination(self) -> None: """ Remove any nodes without users diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 95ee3b74dfa4..74ff87f4fa60 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -2200,3 +2200,10 @@ def ir_dataclass(cls=None, /, *, frozen: bool = True): if cls is None: return wrap return wrap(cls) + + +def get_donated_idxs() -> Optional[List[int]]: + tracing_context = torch._guards.TracingContext.try_get() + if tracing_context is not None and tracing_context.fw_metadata: + return tracing_context.fw_metadata.bw_donated_idxs + return None