mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] let inplace-padding support cpp-wrapper (#145325)
Some context: Inplace padding is an optimization to do padding in place. E.g., if a tensor has size [2048, 2047] and stride [2048, 1]. When we need pad one extra element to the end of each row (e.g. during mm padding), we can just reuse the original tensor and do the padding inplace. This saves memory and bandwidth. One caveat for this optimization is, PyTorch does not allocate 2048 elements for the last row of the original tensor. It only allocate 2047 elements. So assuming the last row having enough space for 2048 elements may be wrong and cause OOB memory access (although I never see this happen maybe due to overallocation in the CUDACachingAllocation, this should better be fixed). The fix is when we allocate the tensor, instead of doing something like: ``` buf0 = randn_strided([2048, 2047], [2048, 1]) ``` we do some small overallocation ``` buf0 = randn_strided([2048, 2048], [2048, 1]).as_strided([2048, 2047], [2048, 1]) ``` cpp_wrapper needs special handling since memory allocation goes thru different code path to python wrapper. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145325 Approved by: https://github.com/desertfire, https://github.com/jansel ghstack dependencies: #140249
This commit is contained in:
committed by
PyTorch MergeBot
parent
f52901a0a7
commit
d3f196909d
@ -1886,6 +1886,13 @@ class PythonWrapperCodegen(CodeGen):
|
||||
)
|
||||
for e in buf.get_size()
|
||||
)
|
||||
allocation_size = tuple(
|
||||
V.graph.sizevars.atomically_apply_size_hint(
|
||||
e,
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
)
|
||||
for e in V.graph.get_allocation_size(buf)
|
||||
)
|
||||
stride = tuple(
|
||||
V.graph.sizevars.atomically_apply_size_hint(
|
||||
e,
|
||||
@ -1899,7 +1906,7 @@ class PythonWrapperCodegen(CodeGen):
|
||||
buf.get_layout().offset,
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
)
|
||||
value = f"generate_example_value({size}, {stride}, '{device}', {dtype}, {offset})"
|
||||
value = f"generate_example_value({size}, {stride}, '{device}', {dtype}, {offset}, {allocation_size})"
|
||||
self.kernel_autotune_calls.writeline(f"{buf_name} = {value}")
|
||||
|
||||
if isinstance(raw_arg, ir.TMADescriptor):
|
||||
|
Reference in New Issue
Block a user