mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Fix redundant H2D/D2H memcpy in cpp_wrapper by creating scalar tensors on CPU (#160584)
Fixes #160520 Summary: When running Inductor with cpp_wrapper under a DeviceContext, non-tensor arguments were being wrapped with torch.tensor(arg) without specifying the device. creating the tensor on the current active device (like CUDA), and later fetching it back to CPU via .item(), causing unnecessary host-device-host memory transfers. PR fixes issue by explicitly creating scalar tensors on the CPU: ``` input_tensors = [ arg if isinstance(arg, torch.Tensor) else torch.tensor(arg, device='cpu') for arg in args ] ``` impact: inductor, codegen Pull Request resolved: https://github.com/pytorch/pytorch/pull/160584 Approved by: https://github.com/benjaminglass1, https://github.com/desertfire, https://github.com/mlazos, https://github.com/jeffdaily
This commit is contained in:
committed by
PyTorch MergeBot
parent
8c98aee436
commit
65ddd91421
@ -62,6 +62,19 @@ class TestGpuWrapper(InductorTestCase):
|
||||
)(test_fn)
|
||||
comp()
|
||||
|
||||
def test_non_tensor_args_wrapped_on_cpu(self):
|
||||
if not RUN_GPU:
|
||||
self.skipTest("GPU not available")
|
||||
|
||||
def test_fn(x, s):
|
||||
return (x + s).sum()
|
||||
|
||||
compiled = torch.compile(options={"cpp_wrapper": True})(test_fn)
|
||||
x = torch.randn(4, device=self.device)
|
||||
with torch.utils._device.DeviceContext(self.device):
|
||||
_, code = test_torchinductor.run_and_get_cpp_code(compiled, x, 3)
|
||||
self.assertIn("torch.tensor(arg, device='cpu')", code)
|
||||
|
||||
|
||||
class DynamicShapesGpuWrapperGpuTests(InductorTestCase):
|
||||
device = GPU_TYPE
|
||||
|
@ -1169,7 +1169,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
||||
"""
|
||||
)
|
||||
|
||||
wrapper_body = "input_tensors = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args]"
|
||||
wrapper_body = "input_tensors = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg, device='cpu') for arg in args]"
|
||||
if V.graph.constants:
|
||||
# Append constants to the input args for cpp wrapper.
|
||||
# Python wrapper directly gets the value inside the wrapper call
|
||||
|
Reference in New Issue
Block a user