mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Reflect back mutation if we clone misaligned tensors (#154442)
Fix for https://github.com/pytorch/pytorch/issues/152425 inductor specializes whether or not a tensor is 16-bit aligned on the first invocation. then, on subsequent invocations, if we inferred alignment but are passed a non-aligned tensor we clone the tensor. If we infer alignment, then run with unaligned, and mutate the input, we need to reflect back the mutation to the input. This pr adds back that mutation. We could have also been less aggressive about inferring alignment for mutated tensors, but that has a pretty perf hit.See the following benchmark: ``` import torch t = torch.rand(4096 * 4096, device="cuda", dtype=torch.float16) @torch.compile(dynamic=False) def foo(x): return x.add_(1) import triton print(triton.testing.do_bench(lambda: foo(t[:-1]))) torch._dynamo.reset() print(triton.testing.do_bench(lambda: foo(t[1:]))) ``` gives ``` 0.04063070610165596 0.07613472988113162 ``` So almost twice as slow for non-aligned tensors. Tensors changing alignment is a relatively rare case. In the future, we could considering a multi-kernel approach, or codegening a triton kernel that does most of the loads with aligned instructions, and a prologue/epilogue of un-alignment. But, it's yet to be seen this is a huge issue. Pull Request resolved: https://github.com/pytorch/pytorch/pull/154442 Approved by: https://github.com/bobrenjc93, https://github.com/bdhirsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
3c74a72ea0
commit
d6e29bf875
@ -1627,6 +1627,24 @@ class CudaReproTests(TestCase):
|
||||
fn(*args)
|
||||
torch.cuda.synchronize() # shake out Triton Error [CUDA]: misaligned address
|
||||
|
||||
def test_mutated_aligned_tensor(self):
|
||||
t = torch.rand(4096, device="cuda", dtype=torch.float16)
|
||||
|
||||
def foo(x):
|
||||
return x.add_(1)
|
||||
|
||||
foo_c = torch.compile(dynamic=False)(foo)
|
||||
|
||||
t_orig = t.clone()
|
||||
|
||||
# First invocation, assume alignment, second invocation,
|
||||
# copy to alignment and then mutate after fn invocation
|
||||
self.assertEqual(foo_c(t[:-1]), foo(t_orig[:-1]))
|
||||
self.assertEqual(t, t_orig)
|
||||
|
||||
self.assertEqual(foo_c(t[1:]), foo(t_orig[1:]))
|
||||
self.assertEqual(t, t_orig)
|
||||
|
||||
def test_non_commutative_scan_op(self):
|
||||
from torch._higher_order_ops.associative_scan import associative_scan
|
||||
|
||||
|
@ -1713,7 +1713,7 @@ def cudagraphify_impl(
|
||||
graph.replay()
|
||||
return static_outputs
|
||||
|
||||
return align_inputs_from_check_idxs(run, check_input_idxs)
|
||||
return align_inputs_from_check_idxs(run, check_input_idxs, OrderedSet())
|
||||
|
||||
|
||||
def compile_fx_aot(
|
||||
|
@ -385,7 +385,11 @@ def cudagraphify_impl(
|
||||
copy_misaligned_inputs(inputs, check_input_idxs)
|
||||
|
||||
fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs)
|
||||
fn = align_inputs_from_check_idxs(fn, inputs_to_check=check_input_idxs)
|
||||
# cudagraph will already clones input locally, no need to copy back
|
||||
mutated_input_idxs: OrderedSet[int] = OrderedSet()
|
||||
fn = align_inputs_from_check_idxs(
|
||||
fn, inputs_to_check=check_input_idxs, mutated_input_idxs=mutated_input_idxs
|
||||
)
|
||||
fn_cache[int_key] = fn
|
||||
|
||||
return out
|
||||
|
@ -325,6 +325,7 @@ def maybe_realign_inputs(
|
||||
ran_cudagraphs: BoxedBool,
|
||||
compiled_graph: CompiledFxGraph,
|
||||
inputs_to_check: Sequence[int],
|
||||
mutated_inputs_idxs: OrderedSet[int],
|
||||
) -> None:
|
||||
"""
|
||||
Realigns input strides from inputs_to_check if
|
||||
@ -335,7 +336,7 @@ def maybe_realign_inputs(
|
||||
if not ran_cudagraphs:
|
||||
assert compiled_graph.current_callable is not None
|
||||
new_callable = align_inputs_from_check_idxs(
|
||||
compiled_graph.current_callable, inputs_to_check
|
||||
compiled_graph.current_callable, inputs_to_check, mutated_inputs_idxs
|
||||
)
|
||||
if new_callable is not compiled_graph.current_callable:
|
||||
compiled_graph.current_callable = new_callable
|
||||
@ -654,6 +655,7 @@ class CompiledFxGraph(OutputCode):
|
||||
cudagraphs,
|
||||
self,
|
||||
inputs_to_check,
|
||||
self.mutated_input_idxs,
|
||||
)
|
||||
|
||||
def set_triton_bundle(self, triton_bundle: Any) -> None:
|
||||
|
@ -2652,13 +2652,23 @@ def shape_env_from_inputs(inputs: Sequence[InputType]) -> Optional[ShapeEnv]:
|
||||
def align_inputs_from_check_idxs(
|
||||
model: Callable[[list[InputType]], _T],
|
||||
inputs_to_check: Sequence[int],
|
||||
mutated_input_idxs: OrderedSet[int],
|
||||
) -> Callable[[list[InputType]], _T]:
|
||||
if len(inputs_to_check) == 0:
|
||||
return model
|
||||
|
||||
def run(new_inputs: list[InputType]) -> Any:
|
||||
copy_misaligned_inputs(new_inputs, inputs_to_check)
|
||||
return model(new_inputs)
|
||||
old_tensors, new_tensors = copy_misaligned_inputs(
|
||||
new_inputs, inputs_to_check, mutated_input_idxs
|
||||
)
|
||||
out = model(new_inputs)
|
||||
|
||||
# If a mutated tensor was cloned to be aligned, we need to reflect back the mutation to the
|
||||
# original tensor.
|
||||
if len(old_tensors):
|
||||
torch._foreach_copy_(old_tensors, new_tensors)
|
||||
|
||||
return out
|
||||
|
||||
return run
|
||||
|
||||
@ -2676,14 +2686,32 @@ def clone_preserve_strides(x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
def copy_misaligned_inputs(
|
||||
new_inputs: list[InputType], check_inputs_idxs: Sequence[int]
|
||||
) -> None:
|
||||
new_inputs: list[InputType],
|
||||
check_inputs_idxs: Sequence[int],
|
||||
return_pair_idxs: Optional[OrderedSet[int]] = None,
|
||||
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
|
||||
"""
|
||||
Clones misaligned tensors which we inferred were aligned. Returns a tuple of [old_tensors], [new_tensors] for every
|
||||
cloned tensor which is in `return_pair_idxs`.
|
||||
"""
|
||||
|
||||
old_tensors: list[torch.Tensor] = []
|
||||
new_tensors: list[torch.Tensor] = []
|
||||
|
||||
# hoist above loop because this is on the hot path
|
||||
ret_pair_defined = return_pair_idxs is not None
|
||||
for i in check_inputs_idxs:
|
||||
_inp = new_inputs[i]
|
||||
assert isinstance(_inp, torch.Tensor)
|
||||
if _inp.data_ptr() % ALIGNMENT:
|
||||
new_inputs[i] = clone_preserve_strides(_inp)
|
||||
|
||||
if ret_pair_defined and i in return_pair_idxs: # type: ignore[operator]
|
||||
old_tensors.append(_inp)
|
||||
new_tensors.append(new_inputs[i]) # type: ignore[arg-type]
|
||||
|
||||
return old_tensors, new_tensors
|
||||
|
||||
|
||||
def remove_unaligned_input_idxs(
|
||||
inputs: Sequence[InputType],
|
||||
|
Reference in New Issue
Block a user