mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Inductor] Inplacing with Donated Buffer (#140113)
Currently, inductor does not inplace update a buffer if it is an input buffer. Because we don't know if an input will be used by other functions. Donated buffer provides additional information that an input buffer will not be used by other functions. So we can inplace update donated buffer when possible. [Dashboard](https://hud.pytorch.org/benchmark/torchbench/inductor_dynamic?dashboard=torchinductor&startTime=Mon,%2011%20Nov%202024%2018:14:36%20GMT&stopTime=Mon,%2018%20Nov%202024%2018:14:36%20GMT&granularity=hour&mode=training&dtype=amp&deviceName=cuda%20(a100)&lBranch=bf/donated-buffer-inplace&lCommit=5df0769c00e6f9000caeb10fd5cbf0b165f69c2a&rBranch=main&rCommit=2b39a8db7741b816b03677a9c6fec1af05640dee)   Pull Request resolved: https://github.com/pytorch/pytorch/pull/140113 Approved by: https://github.com/eellison
This commit is contained in:
committed by
PyTorch MergeBot
parent
3ef031909f
commit
eecc8e362c
@ -5198,6 +5198,31 @@ class CommonTemplate:
|
||||
if self.device != "cpu":
|
||||
assertGeneratedKernelCountEqual(self, 1)
|
||||
|
||||
def test_matmul_layer_norm(self):
|
||||
batch_size = 32
|
||||
seq_length = 50
|
||||
hidden_size = 256
|
||||
|
||||
inp = torch.randn(
|
||||
batch_size,
|
||||
seq_length,
|
||||
hidden_size,
|
||||
requires_grad=True,
|
||||
device=self.device,
|
||||
)
|
||||
weight = torch.randn(
|
||||
hidden_size, hidden_size, requires_grad=True, device=self.device
|
||||
)
|
||||
|
||||
layer_norm = torch.nn.LayerNorm(hidden_size, device=self.device)
|
||||
|
||||
def foo(inp, weight):
|
||||
matmul_output = inp @ weight
|
||||
final_output = layer_norm(matmul_output)
|
||||
return final_output
|
||||
|
||||
self.common(foo, (inp, weight), check_lowp=False)
|
||||
|
||||
def test_transpose_add(self):
|
||||
def fn(a, b):
|
||||
return a.t() + b
|
||||
@ -12855,6 +12880,43 @@ if HAS_GPU and not TEST_WITH_ASAN:
|
||||
self.assertTrue(len(re.findall(r"in_out_ptr\d+", code)) > 0)
|
||||
self.assertEqual(fn_opt(*inps), fn(*inps))
|
||||
|
||||
def test_donated_buffer_inplace(self):
|
||||
batch_size = 32
|
||||
seq_length = 50
|
||||
hidden_size = 256
|
||||
|
||||
inp = torch.randn(
|
||||
batch_size,
|
||||
seq_length,
|
||||
hidden_size,
|
||||
requires_grad=True,
|
||||
device=self.device,
|
||||
)
|
||||
weight = torch.randn(
|
||||
hidden_size, hidden_size, requires_grad=True, device=self.device
|
||||
)
|
||||
|
||||
layer_norm = torch.nn.LayerNorm(hidden_size, device=self.device)
|
||||
|
||||
def fn(inp, weight):
|
||||
matmul_output = inp @ weight
|
||||
final_output = layer_norm(matmul_output)
|
||||
return final_output
|
||||
|
||||
fn_opt = torch.compile(fn)
|
||||
|
||||
def wrapper(inp, weight):
|
||||
return fn_opt(inp, weight).sum().backward()
|
||||
|
||||
_, code = run_and_get_code(wrapper, inp, weight)
|
||||
|
||||
if config.cpp_wrapper:
|
||||
# when using cpp_wrapper, backward triton code is in code[2]
|
||||
self.assertTrue("in_out_ptr" in code[2])
|
||||
else:
|
||||
# when not using cpp_wrapper, backward triton code is in code[1]
|
||||
self.assertTrue("in_out_ptr" in code[1])
|
||||
|
||||
class RNNTest(TestCase):
|
||||
device_type = GPU_TYPE
|
||||
|
||||
|
@ -2120,7 +2120,11 @@ class PythonWrapperCodegen(CodeGen):
|
||||
def codegen_allocation(self, buffer: ir.Buffer):
|
||||
name = buffer.get_name()
|
||||
|
||||
if name in V.graph.removed_buffers or name in self.allocated:
|
||||
if (
|
||||
name in V.graph.removed_buffers
|
||||
or name in self.allocated
|
||||
or isinstance(buffer, ir.DonatedBuffer)
|
||||
):
|
||||
return
|
||||
self.allocated.add(name)
|
||||
if isinstance(
|
||||
@ -2174,7 +2178,12 @@ class PythonWrapperCodegen(CodeGen):
|
||||
name = input_buffer.get_name()
|
||||
return not (
|
||||
name in V.graph.removed_buffers
|
||||
or name in V.graph.graph_inputs
|
||||
or (
|
||||
name in V.graph.graph_inputs
|
||||
and not isinstance(
|
||||
V.graph.graph_inputs_original[name], ir.DonatedBuffer
|
||||
)
|
||||
)
|
||||
or name in V.graph.constants
|
||||
or name in V.graph.torchbind_constants
|
||||
or name in V.graph.never_reuse_buffers
|
||||
|
@ -832,6 +832,20 @@ class CUDAGraphNode:
|
||||
if isinstance(t, torch.Tensor) and self._is_cuda_graph_recorded_tensor(t)
|
||||
]
|
||||
|
||||
# (depth, offset) of live tensors which are alias of previous graph outputs
|
||||
self.live_cudagraph_managed_path_refs: InputList[Optional[PathOutputIndex]] = [
|
||||
(
|
||||
self._is_alias_of_live_recorded_tensor(t)
|
||||
if isinstance(t, torch.Tensor)
|
||||
else None
|
||||
)
|
||||
for t in inputs
|
||||
]
|
||||
|
||||
# when replay, preserve the liveness of an input if it AliasesPriorGraphOutput
|
||||
# and also aliases an output of the current CUDAGraphNode
|
||||
self.preserved_aliased_inputs: InputList[bool] = [False] * len(inputs)
|
||||
|
||||
self.static_input_idxs: List[int] = list(
|
||||
set(wrapped_function.static_input_idxs) | set(self.cudagraph_managed_idxs)
|
||||
)
|
||||
@ -1038,11 +1052,11 @@ class CUDAGraphNode:
|
||||
self.check_static_inputs_are_stable(new_inputs)
|
||||
|
||||
self._copy_inputs_and_remove_from_src(self.reconstructed_inputs, new_inputs)
|
||||
new_inputs.clear()
|
||||
|
||||
self.run_graph()
|
||||
|
||||
outputs = self.reconstruct_outputs()
|
||||
new_inputs.clear()
|
||||
|
||||
if config.triton.fast_path_cudagraph_asserts:
|
||||
self.debug_check_invariants_after_invocation()
|
||||
@ -1261,6 +1275,12 @@ class CUDAGraphNode:
|
||||
path_ref = self._is_alias_of_live_recorded_tensor(o)
|
||||
if path_ref is not None:
|
||||
self._mark_prior_graph_output_as_aliased(path_ref)
|
||||
|
||||
for idx, inp_path_ref in enumerate(
|
||||
self.live_cudagraph_managed_path_refs
|
||||
):
|
||||
if path_ref == inp_path_ref:
|
||||
self.preserved_aliased_inputs[idx] = True
|
||||
self.output_storage_alias.append(AliasesPriorGraphOutput(path_ref))
|
||||
continue
|
||||
|
||||
@ -1667,7 +1687,8 @@ class CUDAGraphNode:
|
||||
# this invocation. it is too late to check after we've replayed the graph,
|
||||
# because we would have already written over their memory.
|
||||
for idx in self.cudagraph_managed_idxs:
|
||||
inputs[idx] = None # type: ignore[call-overload]
|
||||
if not self.preserved_aliased_inputs[idx]:
|
||||
inputs[idx] = None # type: ignore[call-overload]
|
||||
|
||||
torch._check(
|
||||
self._check_liveness(
|
||||
|
@ -74,6 +74,7 @@ from .exc import (
|
||||
)
|
||||
from .ir import (
|
||||
Constant,
|
||||
DonatedBuffer,
|
||||
FixedLayout,
|
||||
get_device_type,
|
||||
InputBuffer,
|
||||
@ -103,6 +104,7 @@ from .utils import (
|
||||
convert_shape_to_inductor,
|
||||
gather_origins,
|
||||
get_cloned_parameter_buffer_name,
|
||||
get_donated_idxs,
|
||||
get_sympy_Expr_dtype,
|
||||
is_same_tensor,
|
||||
maybe_get_suppress_shape_guards_ctx,
|
||||
@ -486,6 +488,11 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
# state used by for Kernel.workspace
|
||||
self.workspace_id = itertools.count()
|
||||
|
||||
# track the current placeholder index that we are processing
|
||||
self.placeholder_idx = -1
|
||||
|
||||
self.bw_donated_idxs = get_donated_idxs()
|
||||
|
||||
def has_feature(
|
||||
self,
|
||||
device: Union[torch._inductor.ir.IRNode, device, None],
|
||||
@ -963,6 +970,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
def placeholder(
|
||||
self, target: str, args: Tuple[object], kwargs: Dict[str, object] # type: ignore[override]
|
||||
) -> Union[Expr, TensorBox, None]:
|
||||
self.placeholder_idx += 1
|
||||
example = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
|
||||
target = self.qualify_name(target)
|
||||
if isinstance(example, SymTypes):
|
||||
@ -993,13 +1001,27 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
sizes, strides = self.static_sizes_strides(example)
|
||||
else:
|
||||
sizes, strides = self.symbolic_sizes_strides(example) # type: ignore[assignment]
|
||||
# TODO(jansel): handle input aliasing
|
||||
tensor = TensorBox.create(
|
||||
InputBuffer(
|
||||
name=target,
|
||||
layout=FixedLayout(example.device, example.dtype, sizes, strides),
|
||||
|
||||
if (
|
||||
self.is_backward
|
||||
and self.bw_donated_idxs
|
||||
and self.placeholder_idx in self.bw_donated_idxs
|
||||
):
|
||||
tensor = TensorBox.create(
|
||||
DonatedBuffer(
|
||||
name=target,
|
||||
layout=FixedLayout(example.device, example.dtype, sizes, strides),
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
# TODO(jansel): handle input aliasing
|
||||
tensor = TensorBox.create(
|
||||
InputBuffer(
|
||||
name=target,
|
||||
layout=FixedLayout(example.device, example.dtype, sizes, strides),
|
||||
)
|
||||
)
|
||||
|
||||
self.graph_inputs[target] = tensor
|
||||
self.graph_input_names.append(target)
|
||||
self.graph_inputs_original[target] = tensor.data.data
|
||||
|
@ -3832,6 +3832,16 @@ class InputBuffer(Buffer):
|
||||
return 1
|
||||
|
||||
|
||||
class DonatedBuffer(InputBuffer):
|
||||
"""
|
||||
Represents a donated buffer which is a saved tensor that is not alias to any
|
||||
fwd inputs, fwd user outputs, and bwd outputs. We generally cannot inplace
|
||||
reuse the input tensor memory during backward since it might be used in another
|
||||
function. However, donated buffer can be inplace reused during backward
|
||||
to save memory.
|
||||
"""
|
||||
|
||||
|
||||
class ConstantBuffer(InputBuffer):
|
||||
override_device: Optional[torch.device] = None
|
||||
|
||||
|
@ -125,10 +125,16 @@ class SchedulerBuffer:
|
||||
hasattr(V.kernel, "args")
|
||||
and self.get_name() in V.kernel.inplace_update_buffers
|
||||
):
|
||||
input_buffer: Union[ir.DonatedBuffer, ir.Buffer]
|
||||
input_buffer_name = V.kernel.inplace_update_buffers[self.get_name()]
|
||||
if input_buffer_name in self.scheduler.name_to_donated_buffer:
|
||||
input_buffer = self.scheduler.name_to_donated_buffer[
|
||||
input_buffer_name
|
||||
].node
|
||||
else:
|
||||
input_buffer = self.scheduler.name_to_buf[input_buffer_name].node
|
||||
V.graph.wrapper_code.codegen_inplace_reuse(
|
||||
self.scheduler.name_to_buf[
|
||||
V.kernel.inplace_update_buffers[self.get_name()]
|
||||
].node,
|
||||
input_buffer,
|
||||
self.node,
|
||||
)
|
||||
else:
|
||||
@ -163,6 +169,11 @@ class SchedulerBuffer:
|
||||
return self.node.get_mutation_names()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SchedulerDonatedBuffer(SchedulerBuffer):
|
||||
defining_op: Optional[BaseSchedulerNode] = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class BaseSchedulerNode:
|
||||
group: Tuple[torch.device, Tuple[Tuple[sympy.Expr, ...], ...]]
|
||||
read_writes: dependencies.ReadWrites
|
||||
@ -442,9 +453,12 @@ class BaseSchedulerNode:
|
||||
continue
|
||||
|
||||
for read in self.read_writes.reads:
|
||||
input_buf: Optional[SchedulerBuffer] = self.scheduler.name_to_buf.get(
|
||||
read.name
|
||||
)
|
||||
input_buf: Optional[Union[SchedulerBuffer, SchedulerDonatedBuffer]]
|
||||
if read.name in self.scheduler.name_to_donated_buffer:
|
||||
input_buf = self.scheduler.name_to_donated_buffer[read.name]
|
||||
else:
|
||||
input_buf = self.scheduler.name_to_buf.get(read.name)
|
||||
|
||||
if (
|
||||
input_buf
|
||||
and V.graph.wrapper_code.can_reuse(input_buf, self)
|
||||
@ -470,7 +484,8 @@ class BaseSchedulerNode:
|
||||
),
|
||||
)
|
||||
and not (
|
||||
isinstance(
|
||||
input_buf.defining_op
|
||||
and isinstance(
|
||||
input_buf.defining_op.node,
|
||||
(ir.FallbackKernel, ir.MultiOutput),
|
||||
)
|
||||
@ -1801,6 +1816,9 @@ class Scheduler:
|
||||
for node in self.nodes:
|
||||
node.prune_deps()
|
||||
|
||||
self.name_to_donated_buffer: Dict[
|
||||
str, SchedulerDonatedBuffer
|
||||
] = self.get_donated_buffers()
|
||||
self.name_to_node: Dict[str, BaseSchedulerNode] = {
|
||||
n.get_name(): n for n in self.nodes
|
||||
}
|
||||
@ -1884,6 +1902,17 @@ class Scheduler:
|
||||
}
|
||||
)
|
||||
|
||||
def get_donated_buffers(self) -> Dict[str, SchedulerDonatedBuffer]:
|
||||
name_to_donated_buf = {}
|
||||
for name in V.graph.graph_inputs_original:
|
||||
if isinstance(V.graph.graph_inputs_original[name], ir.DonatedBuffer):
|
||||
name_to_donated_buf[name] = SchedulerDonatedBuffer(
|
||||
self,
|
||||
V.graph.graph_inputs_original[name],
|
||||
defining_op=None,
|
||||
)
|
||||
return name_to_donated_buf
|
||||
|
||||
@property
|
||||
def current_device(self) -> Optional[torch.device]:
|
||||
return V.graph.current_device
|
||||
@ -2160,6 +2189,9 @@ class Scheduler:
|
||||
for buf in node.get_outputs():
|
||||
buf.set_users(name_to_users[buf.get_name()].items)
|
||||
|
||||
for name in self.name_to_donated_buffer:
|
||||
self.name_to_donated_buffer[name].set_users(name_to_users[name].items)
|
||||
|
||||
def dead_node_elimination(self) -> None:
|
||||
"""
|
||||
Remove any nodes without users
|
||||
|
@ -2200,3 +2200,10 @@ def ir_dataclass(cls=None, /, *, frozen: bool = True):
|
||||
if cls is None:
|
||||
return wrap
|
||||
return wrap(cls)
|
||||
|
||||
|
||||
def get_donated_idxs() -> Optional[List[int]]:
|
||||
tracing_context = torch._guards.TracingContext.try_get()
|
||||
if tracing_context is not None and tracing_context.fw_metadata:
|
||||
return tracing_context.fw_metadata.bw_donated_idxs
|
||||
return None
|
||||
|
Reference in New Issue
Block a user