Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[graph partition] support graphsafe_run_with_rng_state (#150958)
Prior to this PR, `rng_state` is in `V.graph.graph_inputs` but not in the read_writes of any IRNode. As a result, it is not identified as a partition input:

```python
def partition_0(args):
    primals_2, primals_1 = args
    ...
    buf0 = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype=torch.float32, device=device(type='cuda', index=1), pin_memory=False, rng_state=fwd_rng_state_0) # <----- accesses fwd_rng_state_0, but it's not an input
    ...

def call(self, args):
    primals_1, primals_2, fwd_rng_state_0 = args
    ...
    partition0_args = [primals_2, primals_1]
    (buf2, primals_2, primals_1) = self.partitions[0](partition0_args) # <---- fwd_rng_state_0 is in graph_inputs but is not passed to partitions[0]
    ...
```

This PR fixes this issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150958
Approved by: https://github.com/eellison
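With the RNG state tracked as a read dependency, the partitioned output is expected to thread the state through to the partition call, roughly like this (an illustrative sketch based on the snippet above, not actual compiler output; exact signatures may differ):

```python
def partition_0(args):
    primals_2, primals_1, fwd_rng_state_0 = args  # rng state now arrives as a partition input
    ...
    buf0 = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype=torch.float32, device=device(type='cuda', index=1), pin_memory=False, rng_state=fwd_rng_state_0)
    ...

def call(self, args):
    primals_1, primals_2, fwd_rng_state_0 = args
    ...
    partition0_args = [primals_2, primals_1, fwd_rng_state_0]  # <---- fwd_rng_state_0 is now forwarded to partitions[0]
    (buf2, primals_2, primals_1) = self.partitions[0](partition0_args)
    ...
```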
Committed by: PyTorch MergeBot
Parent: 397d37acc5
Commit: c1470d4dc4
```diff
@@ -3390,6 +3390,10 @@ if HAS_CUDA:
         def test_cudagraphs_aot_eager_compat_equal_device_one(self):
             self._test_cudagraphs_aot_eager_compat_equal(torch.device("cuda:1"))
 
+        @config.patch(graph_partition=True)
+        def test_graph_partition_cudagraphs_aot_eager_compat_equal(self):
+            self._test_cudagraphs_aot_eager_compat_equal(torch.device("cuda:0"))
+
         @requires_multigpu()
         def test_multi_device(self):
             def gn(x, y):
```
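For context, the behavior exercised by the new test, graph partitioning under cudagraphs with an RNG op in the graph, can be reproduced in a standalone script along the following lines (a minimal sketch, assuming a CUDA build; `graph_partition` is the inductor config key used by the `@config.patch` decorator above, and the random op is only there so the compiled graph carries RNG state):

```python
import torch
import torch._inductor.config as inductor_config

def fn(x):
    # torch.rand_like introduces RNG state that inductor must plumb into the partition
    return torch.mm(x, x) + torch.rand_like(x)

if torch.cuda.is_available():
    x = torch.randn(4, 4, device="cuda")
    with inductor_config.patch(graph_partition=True):
        compiled = torch.compile(fn, mode="reduce-overhead")  # reduce-overhead enables cudagraphs
        for _ in range(3):  # a few iterations so cudagraph recording warms up
            compiled(x)
```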
```diff
@@ -6716,6 +6716,18 @@ class FallbackKernel(ExternKernelAlloc):
         for info, arg in torch._library.utils.zip_schema(schema, args, kwargs):
             handle_aliasing_and_mutation(info, arg)
 
+    def get_read_writes(self) -> dependencies.ReadWrites:
+        read_writes = super().get_read_writes()
+
+        if self.op_overload is torch._prims.rng_prims.graphsafe_run_with_rng_state:
+            for arg in self.constant_args:
+                if isinstance(arg, GeneratorState):
+                    read_writes = read_writes.with_read(
+                        dependencies.StarDep(arg.get_name())
+                    )
+
+        return read_writes
+
     def codegen_unbacked_symbol_defs(self, wrapper) -> None:  # type: ignore[no-untyped-def]
         return wrapper.codegen_unbacked_symbol_defs_for_outputs(
             self.get_name(), self.outputs, getattr(self, "unbacked_bindings", None)
```
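The mechanism of the fix: `graphsafe_run_with_rng_state` receives the generator state as a constant (non-tensor) argument, so no read dependency was ever recorded for it; registering a `StarDep` on the generator state's name makes the partitioner see it as an input. A toy model of that effect (not inductor's actual partitioner; names are illustrative):

```python
def partition_inputs(nodes, produced_inside):
    """Toy model: a partition's inputs are everything its nodes read,
    minus whatever the partition itself produces."""
    reads = set()
    for node in nodes:
        reads |= set(node["reads"])
    return sorted(reads - produced_inside)

# Before the fix: the generator state never shows up in the node's reads.
node_before = {"name": "buf0", "reads": ["primals_1", "primals_2"]}
# After the fix: get_read_writes() adds a StarDep on the generator state's name.
node_after = {"name": "buf0", "reads": ["primals_1", "primals_2", "fwd_rng_state_0"]}

print(partition_inputs([node_before], set()))  # ['primals_1', 'primals_2']
print(partition_inputs([node_after], set()))   # ['fwd_rng_state_0', 'primals_1', 'primals_2']
```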
```diff
@@ -94,7 +94,7 @@ def get_freeable_input_buf(
     for node in nodes:
         for dep in node.read_writes.reads:
             if dep.name in graph_inputs and not dep.name.startswith(
-                ("primals_", "arg")
+                ("primals_", "arg", "fwd_rng_state", "bwd_rng_state")
            ):
                 dep_name_to_succ_nodes[dep.name].add(node)
                 dep_name_to_size[dep.name] = _dep_size_hint(dep)
```
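The `get_freeable_input_buf` change above just widens an existing name-prefix filter so RNG-state inputs are never treated as freeable graph-input buffers. A minimal standalone sketch of that check (names illustrative):

```python
# Graph inputs whose names mark them as parameters, arguments, or RNG states are not freeable.
NON_FREEABLE_PREFIXES = ("primals_", "arg", "fwd_rng_state", "bwd_rng_state")

def is_freeable_input(name: str) -> bool:
    return not name.startswith(NON_FREEABLE_PREFIXES)

assert not is_freeable_input("fwd_rng_state_0")
assert not is_freeable_input("primals_3")
assert is_freeable_input("tangents_1")
```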
```diff
@@ -671,6 +671,13 @@ class BaseSchedulerNode:
         ):
             # todo: Calculate this - it's kinda annoying.
             return {}
+        if (
+            isinstance(self, ExternKernelSchedulerNode)
+            and isinstance(self.node, ir.FallbackKernel)
+            and self.node.op_overload
+            is torch._prims.rng_prims.graphsafe_run_with_rng_state
+        ):
+            return {}
 
         def try_size_hint(s: sympy.Expr) -> int:
             return V.graph.sizevars.size_hint(s, fallback=0)
```
```diff
@@ -3908,6 +3915,8 @@ class Scheduler:
             inp = V.graph.graph_inputs[name]
             if isinstance(inp, ir.TorchBindObject):
                 V.graph.wrapper_code.codegen_free(inp)
+            elif isinstance(inp, ir.GeneratorState):
+                continue
             else:
                 storage = inp.data
                 assert (
```
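Finally, the `Scheduler` change skips `GeneratorState` graph inputs in the input-freeing path: they are RNG snapshots rather than tensor buffers, so there is no underlying storage (`inp.data`) to release. A toy illustration of that dispatch (not inductor code):

```python
class GeneratorState:
    """Stand-in for ir.GeneratorState: an RNG snapshot with no tensor storage."""

class TensorInput:
    def __init__(self):
        self.data = object()  # stand-in for the underlying storage

def maybe_free(inp):
    if isinstance(inp, GeneratorState):
        return  # nothing to free; there is no .data to touch
    del inp.data  # release the storage of an ordinary tensor input

maybe_free(GeneratorState())  # no-op
maybe_free(TensorInput())     # drops the storage reference
```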