Fix redundant H2D/D2H memcpy in cpp_wrapper by creating scalar tensors on CPU (#160584)

Fixes #160520

Summary:
When running Inductor with cpp_wrapper under a DeviceContext, non-tensor arguments were being wrapped with torch.tensor(arg) without specifying the device.

creating the tensor on the current active device (like CUDA), and later fetching it back to CPU via .item(), causing unnecessary host-device-host memory transfers.

PR fixes issue by explicitly creating scalar tensors on the CPU:

```
input_tensors = [
    arg if isinstance(arg, torch.Tensor) else torch.tensor(arg, device='cpu')
    for arg in args
]
```

impact: inductor, codegen

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160584
Approved by: https://github.com/benjaminglass1, https://github.com/desertfire, https://github.com/mlazos, https://github.com/jeffdaily
This commit is contained in:
Raman Kumar
2025-09-24 23:40:34 +00:00
committed by PyTorch MergeBot
parent 8c98aee436
commit 65ddd91421
2 changed files with 14 additions and 1 deletions

View File

@ -62,6 +62,19 @@ class TestGpuWrapper(InductorTestCase):
)(test_fn)
comp()
def test_non_tensor_args_wrapped_on_cpu(self):
if not RUN_GPU:
self.skipTest("GPU not available")
def test_fn(x, s):
return (x + s).sum()
compiled = torch.compile(options={"cpp_wrapper": True})(test_fn)
x = torch.randn(4, device=self.device)
with torch.utils._device.DeviceContext(self.device):
_, code = test_torchinductor.run_and_get_cpp_code(compiled, x, 3)
self.assertIn("torch.tensor(arg, device='cpu')", code)
class DynamicShapesGpuWrapperGpuTests(InductorTestCase):
device = GPU_TYPE

View File

@ -1169,7 +1169,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
"""
)
wrapper_body = "input_tensors = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args]"
wrapper_body = "input_tensors = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg, device='cpu') for arg in args]"
if V.graph.constants:
# Append constants to the input args for cpp wrapper.
# Python wrapper directly gets the value inside the wrapper call