[fx] split_module subgraph should always have an output node (#139275)

Fixes https://github.com/pytorch/pytorch/issues/138207

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139275
Approved by: https://github.com/ezyang
This commit is contained in:
kshitij12345
2024-10-31 04:53:16 +00:00
committed by PyTorch MergeBot
parent e3e3ab805b
commit 0cf4cc3d5f
2 changed files with 21 additions and 0 deletions

View File

@ -18,6 +18,7 @@ import torch.fx.experimental.optimization as optimization
from torch.fx._symbolic_trace import symbolic_trace
from torch.fx.experimental import merge_matmul
from torch.fx.experimental.accelerator_partitioner import Partitioner
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.normalize import NormalizeArgs, NormalizeOperators
from torch.fx.experimental.partitioner_utils import (
Device,
@ -819,6 +820,23 @@ terrible spacing
split(x), traced(x)
)
def test_split_module_return_node(self):
def foo(x):
x.add_(1)
gm = make_fx(foo, tracing_mode="fake")(torch.randn(3,))
def cb(_):
return 1
sp_gm = split_module(gm, None, cb)
submod_gm = sp_gm.submod_1
for node in submod_gm.graph.nodes:
if node.op == "output":
break
else:
raise RuntimeError("Expected the subgraph to have an output node.")
def test_split_module_kwargs_expansion(self):
class ModuleWithKwargsExpansion(torch.nn.Module):