Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Fix AOT Graph capture not propagating non_blocking copy parameter to inductor codegen (#136513)

Fixes #136260

**Note**: this is my first code contribution to torch, so please let me know if there's anything I need to fix or some other convention I should follow.

Regarding the bug: re-running the issue's reproduction code

```
import torch

def fn(x):
    return x.to(device="cuda", non_blocking=True)

inp = torch.randn(3, 4)
torch.compile(fn)(inp)
```

shows that the non_blocking argument is now propagated through to codegen properly:

```
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code] TRACED GRAPH
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]  ===== pre insert_deferred_runtime_asserts __compiled_fn_1 =====
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]  <eval_with_key>.0 class GraphModule(torch.nn.Module):
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]     def forward(self, L_x_: "f32[3, 4]"):
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]         l_x_ = L_x_
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]         # File: /home/niklasz/Desktop/pytorch/temp/reproduction.py:4 in fn, code: return x.to(device="cuda", non_blocking=True)
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]         to: "f32[3, 4]" = l_x_.to(device = 'cuda', non_blocking = True); l_x_ = None
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]         return (to,)
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code] TRACED GRAPH
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]  ===== __compiled_fn_1 =====
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]  /home/niklasz/Desktop/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]     def forward(self, L_x_: "f32[3, 4][4, 1]cpu"):
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]         l_x_ = L_x_
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]         # File: /home/niklasz/Desktop/pytorch/temp/reproduction.py:4 in fn, code: return x.to(device="cuda", non_blocking=True)
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]         to: "f32[3, 4][4, 1]cuda:0" = l_x_.to(device = 'cuda', non_blocking = True); l_x_ = None
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]         return (to,)
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]
V0922 20:33:25.404000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:114] [0/0] [__aot_graphs] aot_config id: 0, fw_metadata=ViewAndMutationMeta(input_info=[InputAliasInfo(is_leaf=True, mutates_data=False, mutates_metadata=False, mutations_hidden_from_autograd=True, mutations_under_no_grad_or_inference_mode=False, mutation_inductor_storage_resize=False, mutates_storage_metadata=False, requires_grad=False, keep_input_mutations=True)], output_info=[OutputAliasInfo(output_type=<OutputType.non_alias: 1>, raw_type=<class 'torch._subclasses.functional_tensor.FunctionalTensor'>, base_idx=None, dynamic_dims=set(), requires_grad=False, functional_tensor=None)], num_intermediate_bases=0, keep_input_mutations=True, traced_tangents=[], subclass_inp_meta=[0], subclass_fw_graph_out_meta=[0], subclass_tangent_meta=[], is_train=False, traced_tangent_metas=None, num_symints_saved_for_bw=None, grad_enabled_mutation=None, deterministic=None, static_input_indices=[], tokens={}, indices_of_inputs_that_requires_grad_with_mutations_in_bw=[], bw_donated_idxs=None, num_backward_tokens=0),subclass_metadata=None
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs] TRACED GRAPH
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]  ===== Forward graph 0 =====
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]  /home/niklasz/Desktop/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]     def forward(self, arg0_1: "f32[3, 4][4, 1]cpu"):
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]         # File: /home/niklasz/Desktop/pytorch/temp/reproduction.py:4 in fn, code: return x.to(device="cuda", non_blocking=True)
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]         device_put: "f32[3, 4][4, 1]cuda:0" = torch.ops.prims.device_put.default(arg0_1, device(type='cuda', index=0), True); arg0_1 = None
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]         convert_element_type: "f32[3, 4][4, 1]cuda:0" = torch.ops.prims.convert_element_type.default(device_put, torch.float32); device_put = None
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]         return (convert_element_type,)
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1134] [0/0] [__output_code] Output code written to: /tmp/torchinductor_niklasz/ha/chaai264g6ribfw3q2qhl6ayjtaqaavku5wivxtzw4nabgd6htsv.py
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] Output code:
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] # AOT ID: ['0_inference']
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from ctypes import c_void_p, c_long, c_int
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import torch
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import math
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import random
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import os
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import tempfile
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from math import inf, nan
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.hooks import run_intermediate_hooks
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.utils import maybe_profile
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.codegen.memory_planning import _align as align
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch import device, empty_strided
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.async_compile import AsyncCompile
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.select_algorithm import extern_kernels
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.codegen.multi_kernel import MultiKernelCall
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] aten = torch.ops.aten
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] inductor_ops = torch.ops.inductor
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] _quantized = torch.ops._quantized
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] assert_size_stride = torch._C._dynamo.guards.assert_size_stride
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] alloc_from_pool = torch.ops.inductor._alloc_from_pool
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] async_compile = AsyncCompile()
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] async_compile.wait(globals())
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] del async_compile
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] def call(args):
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     arg0_1, = args
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     args.clear()
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     assert_size_stride(arg0_1, (3, 4), (4, 1))
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     with torch.cuda._DeviceGuard(0):
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]         torch.cuda.set_device(0)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]         buf0 = empty_strided_cuda((3, 4), (4, 1), torch.float32)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]         buf0.copy_(arg0_1, True)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     del arg0_1
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     return (buf0, )
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] def benchmark_compiled_module(times=10, repeat=10):
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     from torch._dynamo.testing import rand_strided
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     from torch._inductor.utils import print_performance
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     arg0_1 = rand_strided((3, 4), (4, 1), device='cpu', dtype=torch.float32)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     fn = lambda: call([arg0_1])
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     return print_performance(fn, times=times, repeat=repeat)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] if __name__ == "__main__":
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     from torch._inductor.wrapper_benchmark import compiled_module_main
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     compiled_module_main('None', benchmark_compiled_module)
```

See the line `buf0.copy_(arg0_1, True)` above. Log setting used: `export TORCH_LOGS="graph_code,aot_graphs,output_code"`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136513
Approved by: https://github.com/eellison
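For anyone who wants to check the fix locally without enabling the logs above, here is a minimal sketch (not part of this PR) that mirrors the test added in the diff below: it compiles the repro and asserts that the generated wrapper source contains a `copy_` call carrying `True`, using `run_and_get_code` and `FileCheck`, both of which this PR already relies on. It assumes a CUDA-capable machine.

```
# Minimal verification sketch (assumes CUDA is available); mirrors the new test below.
import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck


def fn(x):
    return x.to(device="cuda", non_blocking=True)


if torch.cuda.is_available():
    inp = torch.randn(3, 4)
    # run_and_get_code returns the eager result plus the generated wrapper source(s)
    _, (code,) = run_and_get_code(torch.compile(fn), inp)
    # The generated copy should now carry the non_blocking flag, e.g. buf0.copy_(arg0_1, True)
    FileCheck().check("copy_").check_same("True").run(code)
    print("non_blocking propagated to inductor codegen")
```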
Committed by: PyTorch MergeBot
Parent: 19a4d68224
Commit: 3f457ee1f6
@@ -12350,6 +12350,16 @@ if HAS_GPU and not TEST_WITH_ASAN:
             print(p.key_averages().table(max_name_column_width=200))
 
+        def test_non_blocking_copy_codegen(self):
+            # Checks non_blocking arg is present in codegen
+            # (see https://github.com/pytorch/pytorch/issues/136260)
+            def fn(x):
+                return x.to(device=self.device, non_blocking=True)
+
+            inp = torch.randn(3, 4)
+            _, (code,) = run_and_get_code(torch.compile(fn), inp)
+            FileCheck().check("copy_").check_same("True").run(code)
+
 
     class RNNTest(TestCase):
         device_type = GPU_TYPE
 
@@ -2150,7 +2150,7 @@ def _to_copy(
         if dtype is not None and device.type == "cpu":
             x_tensor = torch._prims.convert_element_type(x_tensor, dtype)
             dtype_converted = True
-        x_tensor = torch._prims.device_put(x_tensor, device)
+        x_tensor = torch._prims.device_put(x_tensor, device, non_blocking)
 
     if dtype is not None and not dtype_converted:
         x_tensor = torch._prims.convert_element_type(x_tensor, dtype)
@@ -1800,13 +1800,13 @@ class CppWrapperCpu(PythonWrapperCodegen):
         else:
             return final_tmp_name
 
-    def codegen_device_copy(self, src, dst):
+    def codegen_device_copy(self, src, dst, non_blocking: bool):
         if config.abi_compatible:
             self.writeline(
-                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_tensor_copy_(expensive_copy_to_tensor_if_needed({src}), {dst}));"
+                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_(expensive_copy_to_tensor_if_needed({dst}), {src}, {non_blocking}));"
             )
         else:
-            self.writeline(f"{dst}.copy_({src});")
+            self.writeline(f"{dst}.copy_({src}, {non_blocking});")
 
     def codegen_multi_output(self, name, value):
         # in the abi_compatible mode, outputs are retrieved by passing
@@ -420,17 +420,17 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
             output_args,
         )
 
-    def codegen_device_copy(self, src, dst):
+    def codegen_device_copy(self, src, dst, non_blocking: bool):
         if config.abi_compatible:
             # aoti_torch_tensor_copy_ takes AtenTensorHandle as input,
             # while stack-allocation results in ArrayRefTensor
             # so disable stack allocation here
             self.allow_stack_allocation = False
             self.writeline(
-                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_tensor_copy_(expensive_copy_to_tensor_if_needed({src}), {dst}));"
+                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_(expensive_copy_to_tensor_if_needed({dst}), {src}, {non_blocking}));"
            )
         else:
-            self.writeline(f"{dst}.copy_({src});")
+            self.writeline(f"{dst}.copy_({src}, {non_blocking});")
 
     def codegen_reinterpret_view(
         self, data, size_list, stride_list, offset, writer, dtype=None
@@ -1117,8 +1117,8 @@ class PythonWrapperCodegen(CodeGen):
             f"reinterpret_tensor({data.get_name()}, {size}, {stride}, {offset})"
         )
 
-    def codegen_device_copy(self, src, dst):
-        self.writeline(f"{dst}.copy_({src})")
+    def codegen_device_copy(self, src, dst, non_blocking: bool):
+        self.writeline(f"{dst}.copy_({src}, {non_blocking})")
 
     def codegen_multi_output(self, name, value):
         self.writeline(f"{self.declare}{name} = {value}{self.ending}")
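Purely for illustration (not part of the diff): a tiny sketch of the string this Python wrapper method now formats for a device copy. The `buf0`/`arg0_1` names are placeholders taken from the log in the commit message.

```
# Hypothetical illustration of the wrapper line emitted after this change.
def emit_device_copy(dst: str, src: str, non_blocking: bool) -> str:
    return f"{dst}.copy_({src}, {non_blocking})"


print(emit_device_copy("buf0", "arg0_1", True))  # -> buf0.copy_(arg0_1, True)
```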
@@ -711,7 +711,7 @@ def convert_element_type_noop(x, dtype: torch.dtype):
 
 
 @register_noop_decomp(torch.ops.prims.device_put)
-def device_put_noop(x, device):
+def device_put_noop(x, device, non_blocking=True):
     return x.device == decode_device(device)
 
 
@@ -5357,7 +5357,7 @@ class InplaceCopyFallback(ExternKernel):
 
     def codegen(self, wrapper):
         (dst, src, non_blocking) = self.codegen_args()
-        wrapper.codegen_device_copy(src, dst)
+        wrapper.codegen_device_copy(src, dst, non_blocking)
 
     def should_allocate(self):
         return False
@@ -5592,7 +5592,7 @@ class IndexPutFallback(ExternKernel):
 
 class DeviceCopy(ExternKernelOut):
     @classmethod
-    def create(cls, x, device):
+    def create(cls, x, device, non_blocking):
         if (
             not x.is_extern()
             and all(r in V.graph.constants for r in x.get_read_names())
@@ -5604,6 +5604,7 @@ class DeviceCopy(ExternKernelOut):
             V.graph.add_device_info(x.get_device())
 
         developer_warning("DeviceCopy in input program")
+        constant_args = (non_blocking,)
         return DeviceCopy(
             FlexibleLayout(
                 device=device,
@@ -5611,15 +5612,18 @@ class DeviceCopy(ExternKernelOut):
                 size=x.get_size(),
             ),
             [cls.realize_input(x)],
+            constant_args,
         )
 
     def codegen(self, wrapper):
        args = self.codegen_args()
-        assert len(args) == 1
+        assert len(args) == 2
         if self.output_view:
-            wrapper.codegen_device_copy(args[0], self.output_view.codegen_reference())
+            wrapper.codegen_device_copy(
+                args[0], self.output_view.codegen_reference(), args[1]
+            )
         else:
-            wrapper.codegen_device_copy(args[0], self.codegen_reference())
+            wrapper.codegen_device_copy(args[0], self.codegen_reference(), args[1])
 
 
 class DynamicScalar(ExternKernel):
@@ -687,16 +687,16 @@ def _view_dtype(x: TensorBox, dtype: torch.dtype):
     return to_dtype_bitcast(x, dtype)
 
 
-def to_device(x: TensorBox, device: torch.device, *, copy=False):
+def to_device(x: TensorBox, device: torch.device, *, copy=False, non_blocking=False):
     device = decode_device(device)
     if x.get_device() == device:
         return clone(x) if copy else x
-    return TensorBox.create(ir.DeviceCopy.create(x, device))
+    return TensorBox.create(ir.DeviceCopy.create(x, device, non_blocking))
 
 
 @register_lowering(prims.device_put, type_promotion_kind=None)
-def _device_put(x: TensorBox, device: torch.device):
-    return to_device(x, device, copy=True)
+def _device_put(x: TensorBox, device: torch.device, non_blocking=False):
+    return to_device(x, device, copy=True, non_blocking=non_blocking)
 
 
 def register_pointwise(
@@ -35,6 +35,7 @@ from typing import (
     Protocol,
     Sequence,
     Set,
+    Tuple,
     TypeVar,
     Union,
     ValuesView,
@@ -1304,7 +1305,7 @@ class DebugDirManager:
         torch._dynamo.config.debug_dir_root = self.prev_debug_name
 
 
-def run_and_get_code(fn, *args, **kwargs):
+def run_and_get_code(fn, *args, **kwargs) -> Tuple[Any, List[str]]:
     from .graph import GraphLowering
 
     source_codes: List[str] = []
@@ -1945,16 +1945,19 @@ convert_element_type = _make_prim(
 
 
 def _device_put_meta(
-    a: TensorLikeType, device: Union[str, torch.device]
+    a: TensorLikeType, device: Union[str, torch.device], non_blocking=False
 ) -> TensorLikeType:
     assert isinstance(a, TensorLike)
     assert isinstance(device, (str, torch.device))
+    assert isinstance(non_blocking, bool)
 
     return TensorMeta(a, device=utils.canonicalize_device(device))
 
 
-def _device_put_aten(a: Tensor, device: Union[str, torch.device]) -> Tensor:
-    return a.to(device)
+def _device_put_aten(
+    a: Tensor, device: Union[str, torch.device], non_blocking=False
+) -> Tensor:
+    return a.to(device, non_blocking=non_blocking)
 
 
 _device_put_doc = """
@@ -1962,7 +1965,7 @@ _device_put_doc = """
 """
 
 device_put = _make_prim(
-    schema="device_put(Tensor a, Device device) -> Tensor",
+    schema="device_put(Tensor a, Device device, bool non_blocking=False) -> Tensor",
     meta=_device_put_meta,
     impl_aten=_device_put_aten,
     return_type=RETURN_TYPE.NEW,
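For reference, a minimal sketch (not part of the diff, assumes a CUDA device) that exercises the updated prim schema directly; `pin_memory()` is used only so the non-blocking host-to-device copy can actually run asynchronously.

```
import torch

# Sketch only: call the prim with the new positional non_blocking argument
# per the updated schema device_put(Tensor a, Device device, bool non_blocking=False).
if torch.cuda.is_available():
    a = torch.randn(3, 4).pin_memory()  # pinned host memory enables a truly async H2D copy
    b = torch.ops.prims.device_put(a, torch.device("cuda", 0), True)
    torch.cuda.synchronize()  # wait for the async copy before reading b
    assert b.device.type == "cuda"
```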