Reflect back mutation if we clone misaligned tensors (#154442)

Fix for https://github.com/pytorch/pytorch/issues/152425

Inductor specializes on whether a tensor is 16-byte aligned at the first invocation. Then, on subsequent invocations, if we inferred alignment but are passed an unaligned tensor, we clone the tensor.

If we infer alignment, then later run with an unaligned tensor, and the kernel mutates its input, we need to reflect the mutation back onto the original input. This PR adds that copy-back.
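
At a high level, the fix wraps the compiled callable so that any mutated input that had to be cloned for alignment gets copied back after the call. Below is a minimal sketch of the pattern, with simplified, hypothetical names (`align_and_reflect`); the real change lives in `align_inputs_from_check_idxs` / `copy_misaligned_inputs`, shown in the diff further down:
```
import torch

ALIGNMENT = 16  # inductor's alignment check is data_ptr() % 16


def align_and_reflect(model, inputs_to_check, mutated_input_idxs):
    # Wrap `model`: clone misaligned inputs we assumed were aligned, run,
    # then copy mutations on the clones back into the original tensors.
    def run(new_inputs):
        old_tensors, new_tensors = [], []
        for i in inputs_to_check:
            inp = new_inputs[i]
            if inp.data_ptr() % ALIGNMENT:
                new_inputs[i] = inp.clone()  # the real helper preserves strides
                if i in mutated_input_idxs:
                    old_tensors.append(inp)
                    new_tensors.append(new_inputs[i])
        out = model(new_inputs)
        if old_tensors:
            # reflect the mutation back onto the caller-visible tensors
            torch._foreach_copy_(old_tensors, new_tensors)
        return out

    return run
```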

We could instead have been less aggressive about inferring alignment for mutated tensors, but that carries a significant perf hit. See the following benchmark:
```
import torch

t = torch.rand(4096 * 4096, device="cuda", dtype=torch.float16)

@torch.compile(dynamic=False)
def foo(x):
    return x.add_(1)

import triton

print(triton.testing.do_bench(lambda: foo(t[:-1])))  # aligned input
torch._dynamo.reset()
print(triton.testing.do_bench(lambda: foo(t[1:])))  # misaligned input
```
gives
```
0.04063070610165596
0.07613472988113162
```
So the unaligned run is almost twice as slow. A tensor changing alignment across invocations is a relatively rare case.
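
For context on why the two slices behave differently: `t[:-1]` shares the base data pointer of `t` (which the CUDA caching allocator aligns generously), while `t[1:]` offsets it by one fp16 element (2 bytes), so it no longer satisfies the 16-byte check. A quick way to see this, assuming a freshly allocated base tensor:
```
import torch

t = torch.rand(4096 * 4096, device="cuda", dtype=torch.float16)

ALIGNMENT = 16  # bytes
print(t[:-1].data_ptr() % ALIGNMENT)  # 0 -> treated as aligned
print(t[1:].data_ptr() % ALIGNMENT)   # 2 -> misaligned, triggers the clone path
```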

In the future, we could consider a multi-kernel approach, or codegen a Triton kernel that does most of its loads with aligned instructions plus an unaligned prologue/epilogue. But it remains to be seen whether this is a significant issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154442
Approved by: https://github.com/bobrenjc93, https://github.com/bdhirsh
Author: eellison
Committed by: PyTorch MergeBot
Date: 2025-05-28 15:30:29 +00:00
Commit: d6e29bf875 (parent: 3c74a72ea0)
5 changed files with 59 additions and 7 deletions


```
@@ -1627,6 +1627,24 @@ class CudaReproTests(TestCase):
         fn(*args)
         torch.cuda.synchronize()  # shake out Triton Error [CUDA]: misaligned address
 
+    def test_mutated_aligned_tensor(self):
+        t = torch.rand(4096, device="cuda", dtype=torch.float16)
+
+        def foo(x):
+            return x.add_(1)
+
+        foo_c = torch.compile(dynamic=False)(foo)
+
+        t_orig = t.clone()
+
+        # First invocation, assume alignment, second invocation,
+        # copy to alignment and then mutate after fn invocation
+        self.assertEqual(foo_c(t[:-1]), foo(t_orig[:-1]))
+        self.assertEqual(t, t_orig)
+
+        self.assertEqual(foo_c(t[1:]), foo(t_orig[1:]))
+        self.assertEqual(t, t_orig)
+
     def test_non_commutative_scan_op(self):
         from torch._higher_order_ops.associative_scan import associative_scan
```


```
@@ -1713,7 +1713,7 @@ def cudagraphify_impl(
         graph.replay()
         return static_outputs
 
-    return align_inputs_from_check_idxs(run, check_input_idxs)
+    return align_inputs_from_check_idxs(run, check_input_idxs, OrderedSet())
 
 
 def compile_fx_aot(
```


```
@@ -385,7 +385,11 @@ def cudagraphify_impl(
     copy_misaligned_inputs(inputs, check_input_idxs)
 
     fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs)
-    fn = align_inputs_from_check_idxs(fn, inputs_to_check=check_input_idxs)
+    # cudagraph will already clone inputs locally, no need to copy back
+    mutated_input_idxs: OrderedSet[int] = OrderedSet()
+    fn = align_inputs_from_check_idxs(
+        fn, inputs_to_check=check_input_idxs, mutated_input_idxs=mutated_input_idxs
+    )
     fn_cache[int_key] = fn
 
     return out
```


```
@@ -325,6 +325,7 @@ def maybe_realign_inputs(
     ran_cudagraphs: BoxedBool,
     compiled_graph: CompiledFxGraph,
     inputs_to_check: Sequence[int],
+    mutated_inputs_idxs: OrderedSet[int],
 ) -> None:
     """
     Realigns input strides from inputs_to_check if
@@ -335,7 +336,7 @@ def maybe_realign_inputs(
     if not ran_cudagraphs:
         assert compiled_graph.current_callable is not None
         new_callable = align_inputs_from_check_idxs(
-            compiled_graph.current_callable, inputs_to_check
+            compiled_graph.current_callable, inputs_to_check, mutated_inputs_idxs
         )
         if new_callable is not compiled_graph.current_callable:
             compiled_graph.current_callable = new_callable
@@ -654,6 +655,7 @@ class CompiledFxGraph(OutputCode):
             cudagraphs,
             self,
             inputs_to_check,
+            self.mutated_input_idxs,
         )
 
     def set_triton_bundle(self, triton_bundle: Any) -> None:
```


```
@@ -2652,13 +2652,23 @@ def shape_env_from_inputs(inputs: Sequence[InputType]) -> Optional[ShapeEnv]:
 def align_inputs_from_check_idxs(
     model: Callable[[list[InputType]], _T],
     inputs_to_check: Sequence[int],
+    mutated_input_idxs: OrderedSet[int],
 ) -> Callable[[list[InputType]], _T]:
     if len(inputs_to_check) == 0:
         return model
 
     def run(new_inputs: list[InputType]) -> Any:
-        copy_misaligned_inputs(new_inputs, inputs_to_check)
-        return model(new_inputs)
+        old_tensors, new_tensors = copy_misaligned_inputs(
+            new_inputs, inputs_to_check, mutated_input_idxs
+        )
+        out = model(new_inputs)
+        # If a mutated tensor was cloned to be aligned, we need to reflect back the
+        # mutation to the original tensor.
+        if len(old_tensors):
+            torch._foreach_copy_(old_tensors, new_tensors)
+        return out
 
     return run
@@ -2676,14 +2686,32 @@ def clone_preserve_strides(x: torch.Tensor) -> torch.Tensor:
 def copy_misaligned_inputs(
-    new_inputs: list[InputType], check_inputs_idxs: Sequence[int]
-) -> None:
+    new_inputs: list[InputType],
+    check_inputs_idxs: Sequence[int],
+    return_pair_idxs: Optional[OrderedSet[int]] = None,
+) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
+    """
+    Clones misaligned tensors which we inferred were aligned. Returns a tuple of
+    [old_tensors], [new_tensors] for every cloned tensor which is in `return_pair_idxs`.
+    """
+    old_tensors: list[torch.Tensor] = []
+    new_tensors: list[torch.Tensor] = []
+
+    # hoist above loop because this is on the hot path
+    ret_pair_defined = return_pair_idxs is not None
+
     for i in check_inputs_idxs:
         _inp = new_inputs[i]
         assert isinstance(_inp, torch.Tensor)
         if _inp.data_ptr() % ALIGNMENT:
             new_inputs[i] = clone_preserve_strides(_inp)
+            if ret_pair_defined and i in return_pair_idxs:  # type: ignore[operator]
+                old_tensors.append(_inp)
+                new_tensors.append(new_inputs[i])  # type: ignore[arg-type]
+
+    return old_tensors, new_tensors
 
 
 def remove_unaligned_input_idxs(
     inputs: Sequence[InputType],
```