Fix AOT Graph capture not propagating non_blocking copy parameter to inductor codegen (#136513)

Fixes #136260

**Note**: this is my first code contribution to torch, so please let me know if there's anything I need to fix or any other conventions I should follow.

Regarding the bug, re-running the issue's reproduction code:
```python
import torch

def fn(x):
    return x.to(device="cuda", non_blocking=True)

inp = torch.randn(3, 4)

torch.compile(fn)(inp)
```

The `non_blocking` argument is now passed through to codegen properly:

```
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code] TRACED GRAPH
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]  ===== pre insert_deferred_runtime_asserts __compiled_fn_1 =====
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]  <eval_with_key>.0 class GraphModule(torch.nn.Module):
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]     def forward(self, L_x_: "f32[3, 4]"):
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]         l_x_ = L_x_
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]          # File: /home/niklasz/Desktop/pytorch/temp/reproduction.py:4 in fn, code: return x.to(device="cuda", non_blocking=True)
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]         to: "f32[3, 4]" = l_x_.to(device = 'cuda', non_blocking = True);  l_x_ = None
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]         return (to,)
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code] TRACED GRAPH
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]  ===== __compiled_fn_1 =====
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]  /home/niklasz/Desktop/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]     def forward(self, L_x_: "f32[3, 4][4, 1]cpu"):
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]         l_x_ = L_x_
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]          # File: /home/niklasz/Desktop/pytorch/temp/reproduction.py:4 in fn, code: return x.to(device="cuda", non_blocking=True)
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]         to: "f32[3, 4][4, 1]cuda:0" = l_x_.to(device = 'cuda', non_blocking = True);  l_x_ = None
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]         return (to,)
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]
V0922 20:33:25.404000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:114] [0/0] [__aot_graphs] aot_config id: 0, fw_metadata=ViewAndMutationMeta(input_info=[InputAliasInfo(is_leaf=True, mutates_data=False, mutates_metadata=False, mutations_hidden_from_autograd=True, mutations_under_no_grad_or_inference_mode=False, mutation_inductor_storage_resize=False, mutates_storage_metadata=False, requires_grad=False, keep_input_mutations=True)], output_info=[OutputAliasInfo(output_type=<OutputType.non_alias: 1>, raw_type=<class 'torch._subclasses.functional_tensor.FunctionalTensor'>, base_idx=None, dynamic_dims=set(), requires_grad=False, functional_tensor=None)], num_intermediate_bases=0, keep_input_mutations=True, traced_tangents=[], subclass_inp_meta=[0], subclass_fw_graph_out_meta=[0], subclass_tangent_meta=[], is_train=False, traced_tangent_metas=None, num_symints_saved_for_bw=None, grad_enabled_mutation=None, deterministic=None, static_input_indices=[], tokens={}, indices_of_inputs_that_requires_grad_with_mutations_in_bw=[], bw_donated_idxs=None, num_backward_tokens=0),subclass_metadata=None
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs] TRACED GRAPH
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]  ===== Forward graph 0 =====
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]  /home/niklasz/Desktop/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]     def forward(self, arg0_1: "f32[3, 4][4, 1]cpu"):
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]          # File: /home/niklasz/Desktop/pytorch/temp/reproduction.py:4 in fn, code: return x.to(device="cuda", non_blocking=True)
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]         device_put: "f32[3, 4][4, 1]cuda:0" = torch.ops.prims.device_put.default(arg0_1, device(type='cuda', index=0), True);  arg0_1 = None
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]         convert_element_type: "f32[3, 4][4, 1]cuda:0" = torch.ops.prims.convert_element_type.default(device_put, torch.float32);  device_put = None
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]         return (convert_element_type,)
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1134] [0/0] [__output_code] Output code written to: /tmp/torchinductor_niklasz/ha/chaai264g6ribfw3q2qhl6ayjtaqaavku5wivxtzw4nabgd6htsv.py
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] Output code:
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] # AOT ID: ['0_inference']
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from ctypes import c_void_p, c_long, c_int
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import torch
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import math
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import random
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import os
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import tempfile
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from math import inf, nan
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.hooks import run_intermediate_hooks
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.utils import maybe_profile
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.codegen.memory_planning import _align as align
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch import device, empty_strided
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.async_compile import AsyncCompile
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.select_algorithm import extern_kernels
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.codegen.multi_kernel import MultiKernelCall
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] aten = torch.ops.aten
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] inductor_ops = torch.ops.inductor
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] _quantized = torch.ops._quantized
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] assert_size_stride = torch._C._dynamo.guards.assert_size_stride
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] alloc_from_pool = torch.ops.inductor._alloc_from_pool
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] async_compile = AsyncCompile()
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] async_compile.wait(globals())
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] del async_compile
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] def call(args):
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     arg0_1, = args
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     args.clear()
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     assert_size_stride(arg0_1, (3, 4), (4, 1))
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     with torch.cuda._DeviceGuard(0):
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]         torch.cuda.set_device(0)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]         buf0 = empty_strided_cuda((3, 4), (4, 1), torch.float32)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]         buf0.copy_(arg0_1, True)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]         del arg0_1
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     return (buf0, )
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] def benchmark_compiled_module(times=10, repeat=10):
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     from torch._dynamo.testing import rand_strided
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     from torch._inductor.utils import print_performance
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     arg0_1 = rand_strided((3, 4), (4, 1), device='cpu', dtype=torch.float32)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     fn = lambda: call([arg0_1])
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     return print_performance(fn, times=times, repeat=repeat)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] if __name__ == "__main__":
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     from torch._inductor.wrapper_benchmark import compiled_module_main
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     compiled_module_main('None', benchmark_compiled_module)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
```
Note the line `buf0.copy_(arg0_1, True)` above. The logs were captured with `export TORCH_LOGS="graph_code,aot_graphs,output_code"`.
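
For a quick programmatic check (a minimal sketch along the lines of the test added in this PR; it assumes a CUDA-capable build), the generated wrapper code can be captured with `run_and_get_code` and scanned with `FileCheck`:

```python
import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck


def fn(x):
    # Same reproduction as above: host-to-device copy with non_blocking=True
    return x.to(device="cuda", non_blocking=True)


inp = torch.randn(3, 4)
# run_and_get_code returns the call result plus the Inductor-generated source strings
_, (code,) = run_and_get_code(torch.compile(fn), inp)
# The wrapper should now contain `buf0.copy_(arg0_1, True)` instead of dropping the flag
FileCheck().check("copy_").check_same("True").run(code)
```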

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136513
Approved by: https://github.com/eellison
10 changed files with 42 additions and 24 deletions; the relevant hunks follow.

```diff
@@ -12350,6 +12350,16 @@ if HAS_GPU and not TEST_WITH_ASAN:
             print(p.key_averages().table(max_name_column_width=200))
 
+        def test_non_blocking_copy_codegen(self):
+            # Checks non_blocking arg is present in codegen
+            # (see https://github.com/pytorch/pytorch/issues/136260)
+            def fn(x):
+                return x.to(device=self.device, non_blocking=True)
+
+            inp = torch.randn(3, 4)
+            _, (code,) = run_and_get_code(torch.compile(fn), inp)
+            FileCheck().check("copy_").check_same("True").run(code)
+
     class RNNTest(TestCase):
         device_type = GPU_TYPE
```

```diff
@@ -2150,7 +2150,7 @@ def _to_copy(
         if dtype is not None and device.type == "cpu":
             x_tensor = torch._prims.convert_element_type(x_tensor, dtype)
             dtype_converted = True
-        x_tensor = torch._prims.device_put(x_tensor, device)
+        x_tensor = torch._prims.device_put(x_tensor, device, non_blocking)
 
     if dtype is not None and not dtype_converted:
         x_tensor = torch._prims.convert_element_type(x_tensor, dtype)
```

```diff
@@ -1800,13 +1800,13 @@ class CppWrapperCpu(PythonWrapperCodegen):
         else:
             return final_tmp_name
 
-    def codegen_device_copy(self, src, dst):
+    def codegen_device_copy(self, src, dst, non_blocking: bool):
         if config.abi_compatible:
             self.writeline(
-                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_tensor_copy_(expensive_copy_to_tensor_if_needed({src}), {dst}));"
+                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_(expensive_copy_to_tensor_if_needed({dst}), {src}, {non_blocking}));"
             )
         else:
-            self.writeline(f"{dst}.copy_({src});")
+            self.writeline(f"{dst}.copy_({src}, {non_blocking});")
 
     def codegen_multi_output(self, name, value):
         # in the abi_compatible mode, outputs are retrieved by passing
```

```diff
@@ -420,17 +420,17 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
             output_args,
         )
 
-    def codegen_device_copy(self, src, dst):
+    def codegen_device_copy(self, src, dst, non_blocking: bool):
         if config.abi_compatible:
             # aoti_torch_tensor_copy_ takes AtenTensorHandle as input,
             # while stack-allocation results in ArrayRefTensor
             # so disable stack allocation here
             self.allow_stack_allocation = False
             self.writeline(
-                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_tensor_copy_(expensive_copy_to_tensor_if_needed({src}), {dst}));"
+                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_(expensive_copy_to_tensor_if_needed({dst}), {src}, {non_blocking}));"
             )
         else:
-            self.writeline(f"{dst}.copy_({src});")
+            self.writeline(f"{dst}.copy_({src}, {non_blocking});")
 
     def codegen_reinterpret_view(
         self, data, size_list, stride_list, offset, writer, dtype=None
```

```diff
@@ -1117,8 +1117,8 @@ class PythonWrapperCodegen(CodeGen):
             f"reinterpret_tensor({data.get_name()}, {size}, {stride}, {offset})"
         )
 
-    def codegen_device_copy(self, src, dst):
-        self.writeline(f"{dst}.copy_({src})")
+    def codegen_device_copy(self, src, dst, non_blocking: bool):
+        self.writeline(f"{dst}.copy_({src}, {non_blocking})")
 
     def codegen_multi_output(self, name, value):
         self.writeline(f"{self.declare}{name} = {value}{self.ending}")
```

```diff
@@ -711,7 +711,7 @@ def convert_element_type_noop(x, dtype: torch.dtype):
 
 
 @register_noop_decomp(torch.ops.prims.device_put)
-def device_put_noop(x, device):
+def device_put_noop(x, device, non_blocking=True):
     return x.device == decode_device(device)
```

```diff
@@ -5357,7 +5357,7 @@ class InplaceCopyFallback(ExternKernel):
     def codegen(self, wrapper):
         (dst, src, non_blocking) = self.codegen_args()
-        wrapper.codegen_device_copy(src, dst)
+        wrapper.codegen_device_copy(src, dst, non_blocking)
 
     def should_allocate(self):
         return False
@@ -5592,7 +5592,7 @@ class IndexPutFallback(ExternKernel):
 
 class DeviceCopy(ExternKernelOut):
     @classmethod
-    def create(cls, x, device):
+    def create(cls, x, device, non_blocking):
         if (
             not x.is_extern()
             and all(r in V.graph.constants for r in x.get_read_names())
@@ -5604,6 +5604,7 @@ class DeviceCopy(ExternKernelOut):
         V.graph.add_device_info(x.get_device())
 
         developer_warning("DeviceCopy in input program")
+        constant_args = (non_blocking,)
         return DeviceCopy(
             FlexibleLayout(
                 device=device,
@@ -5611,15 +5612,18 @@ class DeviceCopy(ExternKernelOut):
                 size=x.get_size(),
             ),
             [cls.realize_input(x)],
+            constant_args,
         )
 
     def codegen(self, wrapper):
         args = self.codegen_args()
-        assert len(args) == 1
+        assert len(args) == 2
         if self.output_view:
-            wrapper.codegen_device_copy(args[0], self.output_view.codegen_reference())
+            wrapper.codegen_device_copy(
+                args[0], self.output_view.codegen_reference(), args[1]
+            )
         else:
-            wrapper.codegen_device_copy(args[0], self.codegen_reference())
+            wrapper.codegen_device_copy(args[0], self.codegen_reference(), args[1])
 
 
 class DynamicScalar(ExternKernel):
```

```diff
@@ -687,16 +687,16 @@ def _view_dtype(x: TensorBox, dtype: torch.dtype):
     return to_dtype_bitcast(x, dtype)
 
 
-def to_device(x: TensorBox, device: torch.device, *, copy=False):
+def to_device(x: TensorBox, device: torch.device, *, copy=False, non_blocking=False):
     device = decode_device(device)
     if x.get_device() == device:
         return clone(x) if copy else x
-    return TensorBox.create(ir.DeviceCopy.create(x, device))
+    return TensorBox.create(ir.DeviceCopy.create(x, device, non_blocking))
 
 
 @register_lowering(prims.device_put, type_promotion_kind=None)
-def _device_put(x: TensorBox, device: torch.device):
-    return to_device(x, device, copy=True)
+def _device_put(x: TensorBox, device: torch.device, non_blocking=False):
+    return to_device(x, device, copy=True, non_blocking=non_blocking)
 
 
 def register_pointwise(
```

```diff
@@ -35,6 +35,7 @@ from typing import (
     Protocol,
     Sequence,
     Set,
+    Tuple,
     TypeVar,
     Union,
     ValuesView,
@@ -1304,7 +1305,7 @@ class DebugDirManager:
         torch._dynamo.config.debug_dir_root = self.prev_debug_name
 
 
-def run_and_get_code(fn, *args, **kwargs):
+def run_and_get_code(fn, *args, **kwargs) -> Tuple[Any, List[str]]:
    from .graph import GraphLowering
 
     source_codes: List[str] = []
```

```diff
@@ -1945,16 +1945,19 @@ convert_element_type = _make_prim(
 
 
 def _device_put_meta(
-    a: TensorLikeType, device: Union[str, torch.device]
+    a: TensorLikeType, device: Union[str, torch.device], non_blocking=False
 ) -> TensorLikeType:
     assert isinstance(a, TensorLike)
     assert isinstance(device, (str, torch.device))
+    assert isinstance(non_blocking, bool)
 
     return TensorMeta(a, device=utils.canonicalize_device(device))
 
 
-def _device_put_aten(a: Tensor, device: Union[str, torch.device]) -> Tensor:
-    return a.to(device)
+def _device_put_aten(
+    a: Tensor, device: Union[str, torch.device], non_blocking=False
+) -> Tensor:
+    return a.to(device, non_blocking=non_blocking)
 
 
 _device_put_doc = """
@@ -1962,7 +1965,7 @@ _device_put_doc = """
 """
 
 device_put = _make_prim(
-    schema="device_put(Tensor a, Device device) -> Tensor",
+    schema="device_put(Tensor a, Device device, bool non_blocking=False) -> Tensor",
     meta=_device_put_meta,
     impl_aten=_device_put_aten,
     return_type=RETURN_TYPE.NEW,
```
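
As a small sanity check of the widened prim schema (a sketch, assuming a CUDA build; the call form mirrors the one used in the `_to_copy` decomposition above), the prim can also be invoked directly with the new argument:

```python
import torch

if torch.cuda.is_available():
    t = torch.randn(3, 4)
    # The third positional argument is the new non_blocking flag:
    # device_put(Tensor a, Device device, bool non_blocking=False) -> Tensor
    out = torch._prims.device_put(t, "cuda", True)
    assert out.device.type == "cuda"
```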