[graph partition] support graphsafe_run_with_rng_state (#150958)

Prior to this PR, `rng_state` is in `V.graph.graph_inputs` but not in the `read_writes` of any IRNode. As a result, it is not identified as a partition input:
```python
def partition_0(args):
    primals_2, primals_1 = args
    ...
    buf0 = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype=torch.float32, device=device(type='cuda', index=1), pin_memory=False, rng_state=fwd_rng_state_0)
    # <----- accesses fwd_rng_state_0, but it is not a partition input
    ...

def call(self, args):
    primals_1, primals_2, fwd_rng_state_0 = args
    ...
    partition0_args = [primals_2, primals_1]
    (buf2, primals_2, primals_1) = self.partitions[0](partition0_args)
    # <---- fwd_rng_state_0 is a graph input but is not passed to partitions[0]
    ...
```

This PR fixes the issue by recording the rng state as a read dependency of the `graphsafe_run_with_rng_state` fallback kernel, so it is recognized as a partition input.
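
With the fix, the rng state shows up in the partition signature. Roughly, the generated code should instead look like this (an illustrative sketch based on the example above, not verbatim Inductor output):

```python
def partition_0(args):
    primals_2, primals_1, fwd_rng_state_0 = args
    ...
    buf0 = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype=torch.float32, device=device(type='cuda', index=1), pin_memory=False, rng_state=fwd_rng_state_0)
    # <----- fwd_rng_state_0 is now a partition input
    ...

def call(self, args):
    primals_1, primals_2, fwd_rng_state_0 = args
    ...
    partition0_args = [primals_2, primals_1, fwd_rng_state_0]
    (buf2, primals_2, primals_1) = self.partitions[0](partition0_args)
    # <---- fwd_rng_state_0 is passed to partitions[0]
    ...
```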

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150958
Approved by: https://github.com/eellison
Author: Boyuan Feng, 2025-04-12 03:17:06 +00:00
Committed by: PyTorch MergeBot
Parent: 397d37acc5
Commit: c1470d4dc4
4 changed files with 26 additions and 1 deletion

```diff
@@ -3390,6 +3390,10 @@ if HAS_CUDA:
         def test_cudagraphs_aot_eager_compat_equal_device_one(self):
             self._test_cudagraphs_aot_eager_compat_equal(torch.device("cuda:1"))
 
+        @config.patch(graph_partition=True)
+        def test_graph_partition_cudagraphs_aot_eager_compat_equal(self):
+            self._test_cudagraphs_aot_eager_compat_equal(torch.device("cuda:0"))
+
         @requires_multigpu()
         def test_multi_device(self):
             def gn(x, y):
```

```diff
@@ -6716,6 +6716,18 @@ class FallbackKernel(ExternKernelAlloc):
         for info, arg in torch._library.utils.zip_schema(schema, args, kwargs):
             handle_aliasing_and_mutation(info, arg)
 
+    def get_read_writes(self) -> dependencies.ReadWrites:
+        read_writes = super().get_read_writes()
+        if self.op_overload is torch._prims.rng_prims.graphsafe_run_with_rng_state:
+            for arg in self.constant_args:
+                if isinstance(arg, GeneratorState):
+                    read_writes = read_writes.with_read(
+                        dependencies.StarDep(arg.get_name())
+                    )
+        return read_writes
+
     def codegen_unbacked_symbol_defs(self, wrapper) -> None:  # type: ignore[no-untyped-def]
         return wrapper.codegen_unbacked_symbol_defs_for_outputs(
             self.get_name(), self.outputs, getattr(self, "unbacked_bindings", None)
```
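
This override is what surfaces the rng state to the partitioner: partition inputs are derived from the read dependencies of the nodes inside a partition, so a graph input that no node reads never gets forwarded. A rough sketch of that idea (not Inductor's actual partitioning code; names are illustrative):

```python
def collect_partition_inputs(partition_nodes, graph_inputs):
    # A graph input becomes a partition argument only if some node in the
    # partition declares a read on it via node.read_writes.
    needed = set()
    for node in partition_nodes:
        for dep in node.read_writes.reads:
            if dep.name in graph_inputs:
                needed.add(dep.name)
    return needed
```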

```diff
@@ -94,7 +94,7 @@ def get_freeable_input_buf(
     for node in nodes:
         for dep in node.read_writes.reads:
             if dep.name in graph_inputs and not dep.name.startswith(
-                ("primals_", "arg")
+                ("primals_", "arg", "fwd_rng_state", "bwd_rng_state")
             ):
                 dep_name_to_succ_nodes[dep.name].add(node)
                 dep_name_to_size[dep.name] = _dep_size_hint(dep)
```
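
The extended prefix tuple keeps rng-state inputs out of the freeable-input-buffer bookkeeping: they are generator states rather than tensor buffers, so, like parameters, they are not candidates for early freeing. A minimal illustration of the filter (the helper name is hypothetical):

```python
# Hypothetical helper mirroring the prefix check above.
NON_FREEABLE_PREFIXES = ("primals_", "arg", "fwd_rng_state", "bwd_rng_state")

def is_freeable_graph_input(name: str) -> bool:
    # Parameter-like and rng-state graph inputs are skipped when collecting
    # buffers that the scheduler may free after their last use.
    return not name.startswith(NON_FREEABLE_PREFIXES)
```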

```diff
@@ -671,6 +671,13 @@ class BaseSchedulerNode:
         ):
             # todo: Calculate this - it's kinda annoying.
             return {}
+        if (
+            isinstance(self, ExternKernelSchedulerNode)
+            and isinstance(self.node, ir.FallbackKernel)
+            and self.node.op_overload
+            is torch._prims.rng_prims.graphsafe_run_with_rng_state
+        ):
+            return {}
 
         def try_size_hint(s: sympy.Expr) -> int:
             return V.graph.sizevars.size_hint(s, fallback=0)
@@ -3908,6 +3915,8 @@ class Scheduler:
             inp = V.graph.graph_inputs[name]
            if isinstance(inp, ir.TorchBindObject):
                V.graph.wrapper_code.codegen_free(inp)
+            elif isinstance(inp, ir.GeneratorState):
+                continue
             else:
                 storage = inp.data
                 assert (
```