Compare commits

...

8 Commits

Author SHA1 Message Date
be420bd67b move cpu scalar tensor to cuda 2025-05-27 15:40:40 -07:00
b103bf233c nit 2025-05-27 11:34:27 -07:00
a42a09b42b Merge branch 'main' into bf/partition-cpu-read 2025-05-27 09:53:57 -07:00
640db06cce Merge branch 'main' into bf/partition-cpu-read 2025-05-21 21:46:44 -07:00
029c49b1ad nit 2025-05-21 17:34:09 -07:00
b1cf67b554 nit 2025-05-21 15:46:06 -07:00
b01c4e7160 Merge branch 'main' into bf/partition-cpu-read 2025-05-21 15:06:36 -07:00
53fc523368 partition on a node that reads/writes to cpu tensors 2025-05-20 21:23:42 -07:00
5 changed files with 129 additions and 2 deletions

View File

@@ -2923,7 +2923,16 @@ main()
self.assertEqual(counters["compiled_autograd"]["captures"], 1)
# Compiled autograd lifts custom autograd.Function bwd instead of tracing it.
# Must skip since we do not know if the cpu scalar will be used only in ATen/prim ops.
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
if inductor_config.graph_partition:
# Instead of skipping cudagraphs, graph partition splits off cpu inputs/outputs and ops
# and cudagraphifies the remaining computation, so there is no cudagraph skip.
expected_cudagraph_skips = 0
else:
expected_cudagraph_skips = 1
self.assertEqual(
counters["inductor"]["cudagraph_skips"], expected_cudagraph_skips
)
@scoped_load_inline
@unittest.skipIf(not HAS_CUDA, "requires cuda")
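The new branch above encodes the behavioral change: with graph_partition enabled, a cpu scalar input no longer forces Inductor to skip cudagraphs. A minimal standalone sketch of observing that counter outside the test suite, assuming a CUDA machine (the function f and the exact call pattern are illustrative, not taken from the PR):

import torch
from torch._dynamo.utils import counters
from torch._inductor import config as inductor_config


def f(x, cpu_scalar):
    return x + cpu_scalar


# With graph_partition enabled, the cpu scalar is split off (or moved to
# cuda by the post-grad pass added below), so the gpu computation can
# still be cudagraphified and no skip should be recorded.
with inductor_config.patch("graph_partition", True):
    compiled = torch.compile(f, mode="reduce-overhead")
    compiled(torch.randn(4, device="cuda"), torch.tensor(1.0))
    print(counters["inductor"]["cudagraph_skips"])  # expected: 0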

View File

@@ -13067,6 +13067,27 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
foo = torch.compile(foo)
foo()
@torch._inductor.config.patch("graph_partition", True)
def test_graph_partition_read_cpu_tensor_graph_input(self):
def f(x, y, cpu_scalar_tensor):
z = x + y
z = z + cpu_scalar_tensor
return z
f = torch.compile(f)
x, y, cpu_scalar_tensor = (
torch.randn(2, 2, device=self.device),
torch.randn(2, 2, device=self.device),
torch.tensor(1, device="cpu"),
)
_, code = run_and_get_code(f, x, y, cpu_scalar_tensor)
if not config.cpp_wrapper:
# f has only 1 scheduler node, which reads cpu_scalar_tensor,
# so we don't have any partitions.
FileCheck().check("runner = Runner(partitions=[])").run(code[0])
@torch._inductor.config.patch("graph_partition", True)
def test_graph_partition_mutation_real_name(self):
def f(x, y, z, other):

View File

@@ -102,7 +102,11 @@ from .debug import DebugContext
from .decomposition import select_decomp_table
from .exc import InductorError
from .fx_passes.joint_graph import joint_graph_passes
from .fx_passes.post_grad import post_grad_passes, view_to_reshape
from .fx_passes.post_grad import (
move_cpu_scalar_tensor_to_cuda,
post_grad_passes,
view_to_reshape,
)
from .fx_passes.pre_grad import pre_grad_passes
from .graph import GraphLowering
from .ir import get_device_type, IRNode
@@ -1132,6 +1136,8 @@ class _InProcessFxCompile(FxCompile):
shape_env = shape_env_from_inputs(example_inputs)
move_cpu_scalar_tensor_to_cuda(gm)
# Convert view to reshape in the graph. This is necessary primarily for
# layout optimization. Do it unconditionally for uniformity.
#

View File

@@ -5,6 +5,7 @@ import itertools
import logging
import operator
from collections import Counter, defaultdict
from collections.abc import Sequence
from typing import Any, Callable, Optional, TypeVar, Union
from typing_extensions import ParamSpec
@@ -1339,6 +1340,57 @@ def view_to_reshape(gm):
nd.target = torch.ops.aten.reshape.default
def find_gpu_device(args: Sequence[Any]) -> Optional[torch.device]:
for arg in args:
if not isinstance(arg, torch.fx.Node):
continue
val = arg.meta["val"]
if isinstance(val, torch.Tensor) and is_gpu(val.device.type):
return val.device
return None
def move_cpu_scalar_tensor_to_cuda(gm: torch.fx.GraphModule):
"""
Move cpu tensor arguments of gpu ops onto the gpu. For every call_function
node that has at least one gpu tensor argument, cpu tensor arguments are
rewritten to go through torch.ops.prims.device_put so the op no longer reads
from cpu. In practice these cpu arguments are scalar tensors.
"""
graph = gm.graph
cpu_to_gpu: dict[torch.fx.Node, torch.fx.Node] = {}
def get_gpu_arg(arg: torch.fx.Node):
if arg in cpu_to_gpu:
arg_gpu = cpu_to_gpu[arg]
else:
with graph.inserting_after(arg):
arg_gpu = graph.call_function(
torch.ops.prims.device_put.default, (arg, gpu_device)
)
cpu_to_gpu[arg] = arg_gpu
return arg_gpu
for node in graph.nodes:
if node.op == "call_function":
gpu_device = find_gpu_device(node.args)
if gpu_device is None:
continue
new_args = []
for arg in node.args:
if (
isinstance(arg, torch.fx.Node)
and (val := arg.meta["val"]) is not None
and isinstance(val, torch.Tensor)
and val.is_cpu
):
new_args.append(get_gpu_arg(arg))
else:
new_args.append(arg)
node.args = tuple(new_args)
gm.recompile()
def should_prefer_unfused_addmm(match):
inp = match.kwargs["inp"]
if not is_gpu(inp.meta["val"].device.type):
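Since move_cpu_scalar_tensor_to_cuda only looks at nodes whose meta["val"] is populated, one way to exercise it in isolation is to trace with make_fx, which fills in that metadata. A rough sketch, assuming a CUDA build; the function f is illustrative:

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._inductor.fx_passes.post_grad import move_cpu_scalar_tensor_to_cuda


def f(x, cpu_scalar):
    return x + cpu_scalar


# make_fx records a tensor in node.meta["val"], which the pass inspects
# to find cpu tensor arguments feeding an op that also has a gpu argument.
gm = make_fx(f)(torch.randn(4, device="cuda"), torch.tensor(1.0))
print(gm.graph)  # add(x, cpu_scalar) with a cpu placeholder feeding it directly

move_cpu_scalar_tensor_to_cuda(gm)
# The cpu placeholder is now routed through prims.device_put onto the
# cuda device before reaching the add.
print(gm.graph)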

View File

@@ -4076,12 +4076,51 @@ class Scheduler:
def should_partition(self, node: BaseSchedulerNode) -> bool:
"""Return True if we should partition the inductor graph on this node"""
def is_non_gpu_tensor(buf: Any) -> bool:
return isinstance(buf, torch.Tensor) and not is_gpu(buf.device.type)
def is_non_gpu_tensor_box(buf: Any) -> bool:
if (
isinstance(buf, ir.TensorBox)
and (device := buf.get_device())
and not is_gpu(device.type)
):
return True
return False
def read_write_non_gpu_data(node: BaseSchedulerNode) -> bool:
for dep in node.read_writes.reads | node.read_writes.writes:
name = dep.name
if (inp := V.graph.graph_inputs.get(name, None)) is not None:
if is_non_gpu_tensor_box(inp):
return True
elif (buf := self.name_to_buf.get(name, None)) is not None:
if is_non_gpu_tensor_box(buf):
return True
elif (tensor := V.graph.constants.get(name, None)) is not None:
if is_non_gpu_tensor(tensor):
return True
else:
assert name in V.graph.torchbind_constants, (
f"Expected all dependences to be either graph_inputs, "
f"name_to_buf, constants, or torchbind_constants, "
f"but found: {dep}"
)
return False
if isinstance(node, FusedSchedulerNode):
return any(self.should_partition(snode) for snode in node.snodes)
if not node.is_gpu():
return True
if read_write_non_gpu_data(node):
return True
if node.node is None:
return True
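The lookup chain in read_write_non_gpu_data is the crux of the new partitioning rule: every dependency name must resolve to a graph input, a scheduler buffer, a constant, or a torchbind constant, and only the first three are checked for a non-gpu device. A hypothetical standalone mirror of that precedence, using plain dicts in place of scheduler state, just to make the classification and the final assert explicit:

from typing import Any


def classify_dep(
    name: str,
    graph_inputs: dict[str, Any],
    name_to_buf: dict[str, Any],
    constants: dict[str, Any],
    torchbind_constants: dict[str, Any],
) -> str:
    # Mirrors the if/elif chain above: the first container that knows the
    # name decides how the dependency is checked for a non-gpu device.
    if name in graph_inputs:
        return "graph_input"
    if name in name_to_buf:
        return "buffer"
    if name in constants:
        return "constant"
    assert name in torchbind_constants, f"unknown dependency: {name}"
    return "torchbind_constant"


print(classify_dep("arg0_1", {"arg0_1": None}, {}, {}, {}))  # graph_input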