[inductor] let inplace-padding support cpp-wrapper (#145325)

Some context: Inplace padding is an optimization that performs padding in place. E.g., suppose a tensor has size [2048, 2047] and stride [2048, 1]. When we need to pad one extra element onto the end of each row (e.g., during mm padding), we can reuse the original tensor and do the padding in place. This saves memory and bandwidth. One caveat for this optimization: PyTorch does not allocate 2048 elements for the last row of the original tensor; it only allocates 2047. So assuming the last row has enough space for 2048 elements may be wrong and can cause out-of-bounds (OOB) memory access (although I have never seen this happen, perhaps due to over-allocation in the CUDACachingAllocator, it is better to fix it).
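The arithmetic behind the caveat can be sketched in plain Python. `storage_elems` below is a hypothetical helper mirroring how the minimal storage for a strided tensor is computed; it is not part of the PyTorch API:

```python
def storage_elems(size, stride):
    # Minimal storage for a strided tensor: offset of the last element, plus one.
    return sum((s - 1) * st for s, st in zip(size, stride)) + 1

size, stride = [2048, 2047], [2048, 1]
allocated = storage_elems(size, stride)  # the last row only owns 2047 elements

# Writing one padding element right after the last row touches this offset:
padded_offset = (size[0] - 1) * stride[0] + size[1]
# padded_offset == allocated, i.e. one element past the end of the storage,
# so the in-place padding write would be out of bounds.
print(allocated, padded_offset)
```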

The fix is: when we allocate the tensor, instead of doing something like
```
  buf0 = randn_strided([2048, 2047], [2048, 1])
```
we over-allocate slightly:
```
  buf0 = randn_strided([2048, 2048], [2048, 1]).as_strided([2048, 2047], [2048, 1])
```
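The size used for the over-allocation can be sketched as follows. This is a hypothetical 2-D simplification of what `V.graph.get_allocation_size` provides, not the actual implementation:

```python
def padded_allocation_size(size, stride):
    # Pad the inner dim up to the row stride so every row, including the
    # last, has capacity for in-place padding (hypothetical 2-D sketch).
    rows, cols = size
    row_stride, inner_stride = stride
    assert inner_stride == 1, "sketch assumes contiguous rows"
    return [rows, max(cols, row_stride)]

print(padded_allocation_size([2048, 2047], [2048, 1]))
```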

cpp_wrapper needs special handling since its memory allocation goes through a different code path than the Python wrapper.
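To illustrate why the two paths differ: the Python wrapper can emit the over-allocate-then-`as_strided` pattern directly as Python source, while the cpp-wrapper path must thread the allocation size through its own C++ allocation calls. A toy emitter for the Python side might look like this (the names `emit_python_alloc` and `empty_strided`-as-a-string are illustrative, not the actual codegen API):

```python
def emit_python_alloc(name, size, stride, alloc_size):
    # Hypothetical sketch of wrapper codegen: when the allocation size differs
    # from the logical size, over-allocate and re-view with as_strided.
    base = f"empty_strided({alloc_size}, {stride})"
    if alloc_size != size:
        return f"{name} = {base}.as_strided({size}, {stride})"
    return f"{name} = {base}"

line = emit_python_alloc("buf0", [2048, 2047], [2048, 1], [2048, 2048])
print(line)
```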

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145325
Approved by: https://github.com/desertfire, https://github.com/jansel
ghstack dependencies: #140249
Author: Shunting Zhang
Date: 2025-01-22 12:08:35 -08:00
Committed by: PyTorch MergeBot
parent f52901a0a7
commit d3f196909d
7 changed files with 118 additions and 23 deletions


```
@@ -1886,6 +1886,13 @@ class PythonWrapperCodegen(CodeGen):
             )
             for e in buf.get_size()
         )
+        allocation_size = tuple(
+            V.graph.sizevars.atomically_apply_size_hint(
+                e,
+                fallback=config.unbacked_symint_fallback,
+            )
+            for e in V.graph.get_allocation_size(buf)
+        )
         stride = tuple(
             V.graph.sizevars.atomically_apply_size_hint(
                 e,
@@ -1899,7 +1906,7 @@ class PythonWrapperCodegen(CodeGen):
             buf.get_layout().offset,
             fallback=config.unbacked_symint_fallback,
         )
-        value = f"generate_example_value({size}, {stride}, '{device}', {dtype}, {offset})"
+        value = f"generate_example_value({size}, {stride}, '{device}', {dtype}, {offset}, {allocation_size})"
         self.kernel_autotune_calls.writeline(f"{buf_name} = {value}")
         if isinstance(raw_arg, ir.TMADescriptor):
```